diff --git a/.gitignore b/.gitignore index 14c6bac..dafc1ed 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ +.idea/ +*.iml +*.ipr dist/ target/ lib_managed/ src_managed/ project/boot/ project/plugins/project/ +querulous.tmproj diff --git a/README.markdown b/README.markdown index 91e09a1..cde889c 100644 --- a/README.markdown +++ b/README.markdown @@ -154,7 +154,7 @@ Add the following dependency and repository stanzas to your project's configurat twitter.com - http://www.lag.net/nest + http://maven.twttr.com/ ### Ivy @@ -165,7 +165,7 @@ Add the following dependency to ivy.xml and the following repository to ivysettings.xml - + ## Running Tests diff --git a/config/test.conf b/config/test.conf deleted file mode 100644 index 6483d9a..0000000 --- a/config/test.conf +++ /dev/null @@ -1,9 +0,0 @@ -db { - hostname = "localhost" - username = "root" - password = "" - url_options { - useUnicode = "true" - characterEncoding = "UTF-8" - } -} diff --git a/config/test.scala b/config/test.scala new file mode 100644 index 0000000..40b5c34 --- /dev/null +++ b/config/test.scala @@ -0,0 +1,15 @@ +import com.twitter.querulous.config.Connection + +new Connection { + val hostnames = Seq("localhost") + val database = "db_test" + val username = { + val userEnv = System.getenv("DB_USERNAME") + if (userEnv == null) "root" else userEnv + } + + val password = { + val passEnv = System.getenv("DB_PASSWORD") + if (passEnv == null) "" else passEnv + } +} diff --git a/project/build.properties b/project/build.properties index 2da0f90..7e84939 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1,8 +1,8 @@ #Project properties -#Fri Sep 17 11:09:45 PDT 2010 +#Mon Feb 28 13:01:14 PST 2011 project.organization=com.twitter project.name=querulous sbt.version=0.7.4 -project.version=1.2.2 -build.scala.versions=2.7.7 +project.version=2.0.2-SNAPSHOT +build.scala.versions=2.8.1 project.initialize=false diff --git a/project/build/QuerulousProject.scala b/project/build/QuerulousProject.scala index 3764ce3..ee9eaee 100644 --- a/project/build/QuerulousProject.scala +++ b/project/build/QuerulousProject.scala @@ -1,30 +1,23 @@ import sbt._ +import Process._ import com.twitter.sbt._ - class QuerulousProject(info: ProjectInfo) extends StandardProject(info) with SubversionPublisher { - val vscaladoc = "org.scala-tools" % "vscaladoc" % "1.1-md-3" - val configgy = "net.lag" % "configgy" % "1.5.2" + override def filterScalaJars = false + + val util = "com.twitter" % "util" % "1.6.4" + val dbcp = "commons-dbcp" % "commons-dbcp" % "1.4" - val mysqljdbc = "mysql" % "mysql-connector-java" % "5.1.6" + val mysqljdbc = "mysql" % "mysql-connector-java" % "5.1.13" val pool = "commons-pool" % "commons-pool" % "1.5.4" - val xrayspecs = "com.twitter" % "xrayspecs" % "1.0.7" - - val hamcrest = "org.hamcrest" % "hamcrest-all" % "1.1" % "test" - val specs = "org.scala-tools.testing" % "specs" % "1.6.2.1" % "test" - val objenesis = "org.objenesis" % "objenesis" % "1.1" % "test" - val jmock = "org.jmock" % "jmock" % "2.4.0" % "test" - val cglib = "cglib" % "cglib" % "2.1_3" % "test" - val asm = "asm" % "asm" % "1.5.3" % "test" - override def pomExtra = - - - Apache 2 - http://www.apache.org/licenses/LICENSE-2.0.txt - repo - - + val scalaTools = "org.scala-lang" % "scala-compiler" % "2.8.1" % "test" + val hamcrest = "org.hamcrest" % "hamcrest-all" % "1.1" % "test" + val specs = "org.scala-tools.testing" % "specs_2.8.0" % "1.6.5" % "test" + val objenesis = "org.objenesis" % "objenesis" % "1.1" % "test" + val jmock = "org.jmock" % "jmock" % "2.4.0" % "test" + val cglib = "cglib" % "cglib" % "2.1_3" % "test" + val asm = "asm" % "asm" % "1.5.3" % "test" override def subversionRepository = Some("http://svn.local.twitter.com/maven-public/") } diff --git a/project/plugins/Plugins.scala b/project/plugins/Plugins.scala index 9bb719d..7e3ef44 100644 --- a/project/plugins/Plugins.scala +++ b/project/plugins/Plugins.scala @@ -1,6 +1,6 @@ import sbt._ class Plugins(info: ProjectInfo) extends PluginDefinition(info) { - val twitterNest = "com.twitter" at "http://www.lag.net/nest" - val defaultProject = "com.twitter" % "standard-project" % "0.7.1" + val twitter = "twitter.com" at "http://maven.twttr.com/" + val defaultProject = "com.twitter" % "standard-project" % "0.7.17" } diff --git a/src/main/scala/com/twitter/querulous/AutoDisabler.scala b/src/main/scala/com/twitter/querulous/AutoDisabler.scala index aefff81..f0d23f7 100644 --- a/src/main/scala/com/twitter/querulous/AutoDisabler.scala +++ b/src/main/scala/com/twitter/querulous/AutoDisabler.scala @@ -1,7 +1,7 @@ package com.twitter.querulous -import com.twitter.xrayspecs.{Time, Duration} -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.{Time, Duration} +import com.twitter.util.TimeConversions._ import java.sql.{SQLException, SQLIntegrityConstraintViolationException} @@ -9,7 +9,7 @@ trait AutoDisabler { protected val disableErrorCount: Int protected val disableDuration: Duration - private var disabledUntil: Time = Time.never + private var disabledUntil: Time = Time.epoch private var consecutiveErrors = 0 protected def throwIfDisabled(throwMessage: String): Unit = { diff --git a/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala b/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala index 907507e..09f010b 100644 --- a/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala +++ b/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala @@ -4,12 +4,10 @@ import java.sql.Connection import org.apache.commons.dbcp.{DelegatingConnection => DBCPConnection} import com.mysql.jdbc.{ConnectionImpl => MySQLConnection} - // Emergency connection destruction toolkit - trait ConnectionDestroying { def destroyConnection(conn: Connection) { - if ( !conn.isClosed ) + if (!conn.isClosed) conn match { case c: DBCPConnection => destroyDbcpWrappedConnection(c) @@ -22,21 +20,18 @@ trait ConnectionDestroying { def destroyDbcpWrappedConnection(conn: DBCPConnection) { val inner = conn.getInnermostDelegate - if ( inner != null ) { + if (inner ne null) { destroyConnection(inner) } else { - // this should never happen if we use our own ApachePoolingDatabase to get connections. - error("Could not get access to the delegate connection. Make sure the dbcp connection pool allows access to underlying connections.") + // might just be a race; move on. + return } // "close" the wrapper so that it updates its internal bookkeeping, just do it - try { conn.close } catch { case _ => } + try { conn.close() } catch { case _ => } } def destroyMysqlConnection(conn: MySQLConnection) { - val abort = Class.forName("com.mysql.jdbc.ConnectionImpl").getDeclaredMethod("abortInternal") - abort.setAccessible(true) - - abort.invoke(conn) + conn.abortInternal() } } diff --git a/src/main/scala/com/twitter/querulous/FutureTimeout.scala b/src/main/scala/com/twitter/querulous/FutureTimeout.scala index 1975742..931f1de 100644 --- a/src/main/scala/com/twitter/querulous/FutureTimeout.scala +++ b/src/main/scala/com/twitter/querulous/FutureTimeout.scala @@ -2,7 +2,7 @@ package com.twitter.querulous import java.util.concurrent.{ThreadFactory, TimeoutException => JTimeoutException, _} import java.util.concurrent.atomic.AtomicInteger -import com.twitter.xrayspecs.Duration +import com.twitter.util.Duration class FutureTimeout(poolSize: Int, queueSize: Int) { object DaemonThreadFactory extends ThreadFactory { @@ -20,7 +20,7 @@ class FutureTimeout(poolSize: Int, queueSize: Int) { thread } } - private val executor = new ThreadPoolExecutor(poolSize, poolSize, 0, TimeUnit.SECONDS, + private val executor = new ThreadPoolExecutor(1, poolSize, 60, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable](queueSize), DaemonThreadFactory) diff --git a/src/main/scala/com/twitter/querulous/StatsCollector.scala b/src/main/scala/com/twitter/querulous/StatsCollector.scala index 7725972..41ccd0a 100644 --- a/src/main/scala/com/twitter/querulous/StatsCollector.scala +++ b/src/main/scala/com/twitter/querulous/StatsCollector.scala @@ -4,3 +4,8 @@ trait StatsCollector { def incr(name: String, count: Int) def time[A](name: String)(f: => A): A } + +object NullStatsCollector extends StatsCollector { + def incr(name: String, count: Int) {} + def time[A](name: String)(f: => A): A = f +} diff --git a/src/main/scala/com/twitter/querulous/Timeout.scala b/src/main/scala/com/twitter/querulous/Timeout.scala index 8167d82..365a583 100644 --- a/src/main/scala/com/twitter/querulous/Timeout.scala +++ b/src/main/scala/com/twitter/querulous/Timeout.scala @@ -1,7 +1,7 @@ package com.twitter.querulous import java.util.{Timer, TimerTask} -import com.twitter.xrayspecs.Duration +import com.twitter.util.Duration class TimeoutException extends Exception @@ -20,7 +20,9 @@ object Timeout { } finally { task map { t => t.cancel() - timer.purge() + // TODO(benjy): Timer is not optimized to deal with large numbers of cancellations: it releases and reacquires its monitor + // on every task, cancelled or not, when it could quickly skip over all cancelled tasks in a single monitor region. + // This may not be a problem, but it's something to be aware of. } if (cancelled) throw new TimeoutException } @@ -32,7 +34,15 @@ object Timeout { private def schedule(timer: Timer, timeout: Duration, f: => Unit) = { val task = new TimerTask() { - override def run() { f } + override def run() { + try { + f + } catch { + case e: Throwable => + error("Timer task tried to throw an exception: " + e.toString()) + e.printStackTrace(System.err) + } + } } timer.schedule(task, timeout.inMillis) task diff --git a/src/main/scala/com/twitter/querulous/config/Database.scala b/src/main/scala/com/twitter/querulous/config/Database.scala new file mode 100644 index 0000000..cd8c650 --- /dev/null +++ b/src/main/scala/com/twitter/querulous/config/Database.scala @@ -0,0 +1,125 @@ +package com.twitter.querulous.config + +import com.twitter.querulous._ +import com.twitter.util.Duration +import com.twitter.util.TimeConversions._ +import database._ + + +class ApachePoolingDatabase { + var sizeMin: Int = 10 + var sizeMax: Int = 10 + var testIdle: Duration = 1.second + var maxWait: Duration = 10.millis + var minEvictableIdle: Duration = 60.seconds + var testOnBorrow: Boolean = false +} + +class TimingOutDatabase { + var poolSize: Int = 10 + var queueSize: Int = 10000 + var open: Duration = 1.second +} + +trait AutoDisablingDatabase { + def errorCount: Int + def interval: Duration +} + +class Database { + var pool: Option[ApachePoolingDatabase] = None + def pool_=(p: ApachePoolingDatabase) { pool = Some(p) } + var autoDisable: Option[AutoDisablingDatabase] = None + def autoDisable_=(a: AutoDisablingDatabase) { autoDisable = Some(a) } + var timeout: Option[TimingOutDatabase] = None + def timeout_=(t: TimingOutDatabase) { timeout = Some(t) } + var memoize: Boolean = true + + def apply(stats: StatsCollector): DatabaseFactory = { + var factory: DatabaseFactory = pool.map(apacheConfig => + new ApachePoolingDatabaseFactory( + apacheConfig.sizeMin, + apacheConfig.sizeMax, + apacheConfig.testIdle, + apacheConfig.maxWait, + apacheConfig.testOnBorrow, + apacheConfig.minEvictableIdle) + ).getOrElse(new SingleConnectionDatabaseFactory) + + timeout.foreach { timeoutConfig => + factory = new TimingOutDatabaseFactory(factory, + timeoutConfig.poolSize, + timeoutConfig.queueSize, + timeoutConfig.open, + timeoutConfig.poolSize) + } + + if (stats ne NullStatsCollector) { + factory = new StatsCollectingDatabaseFactory(factory, stats) + } + + autoDisable.foreach { disable => + factory = new AutoDisablingDatabaseFactory(factory, disable.errorCount, disable.interval) + } + + if (memoize) { + factory = new MemoizingDatabaseFactory(factory) + } + + factory + } + + def apply(): DatabaseFactory = apply(NullStatsCollector) +} + +trait Connection { + def hostnames: Seq[String] + def database: String + def username: String + def password: String + var urlOptions: Map[String, String] = Map() + + def withHost(newHost: String) = { + val current = this + new Connection { + def hostnames = Seq(newHost) + def database = current.database + def username = current.username + def password = current.password + urlOptions = current.urlOptions + } + } + + def withHosts(newHosts: Seq[String]) = { + val current = this + new Connection { + def hostnames = newHosts + def database = current.database + def username = current.username + def password = current.password + urlOptions = current.urlOptions + } + } + + def withDatabase(newDatabase: String) = { + val current = this + new Connection { + def hostnames = current.hostnames + def database = newDatabase + def username = current.username + def password = current.password + urlOptions = current.urlOptions + } + } + + def withoutDatabase = { + val current = this + new Connection { + def hostnames = current.hostnames + def database = null + def username = current.username + def password = current.password + urlOptions = current.urlOptions + } + } +} diff --git a/src/main/scala/com/twitter/querulous/config/Query.scala b/src/main/scala/com/twitter/querulous/config/Query.scala new file mode 100644 index 0000000..2467a17 --- /dev/null +++ b/src/main/scala/com/twitter/querulous/config/Query.scala @@ -0,0 +1,59 @@ +package com.twitter.querulous.config + +import com.twitter.querulous._ +import com.twitter.util.Duration +import com.twitter.util.TimeConversions._ +import query._ + + +object QueryTimeout { + def apply(timeout: Duration, cancelOnTimeout: Boolean) = + new QueryTimeout(timeout, cancelOnTimeout) + + def apply(timeout: Duration) = + new QueryTimeout(timeout, false) +} + +class QueryTimeout(val timeout: Duration, val cancelOnTimeout: Boolean) + +object NoDebugOutput extends (String => Unit) { + def apply(s: String) = () +} + +class Query { + var timeouts: Map[QueryClass, QueryTimeout] = Map( + QueryClass.Select -> QueryTimeout(5.seconds), + QueryClass.Execute -> QueryTimeout(5.seconds) + ) + + var retries: Int = 0 + var debug: (String => Unit) = NoDebugOutput + + def apply(statsCollector: StatsCollector): QueryFactory = { + var queryFactory: QueryFactory = new SqlQueryFactory + + if (!timeouts.isEmpty) { + val tupleTimeout = Map(timeouts.map { case (queryClass, timeout) => + (queryClass, (timeout.timeout, timeout.cancelOnTimeout)) + }.toList: _*) + + queryFactory = new PerQueryTimingOutQueryFactory(new SqlQueryFactory, tupleTimeout) + } + + if (statsCollector ne NullStatsCollector) { + queryFactory = new StatsCollectingQueryFactory(queryFactory, statsCollector) + } + + if (retries > 0) { + queryFactory = new RetryingQueryFactory(queryFactory, retries) + } + + if (debug ne NoDebugOutput) { + queryFactory = new DebuggingQueryFactory(queryFactory, debug) + } + + queryFactory + } + + def apply(): QueryFactory = apply(NullStatsCollector) +} diff --git a/src/main/scala/com/twitter/querulous/config/QueryEvaluator.scala b/src/main/scala/com/twitter/querulous/config/QueryEvaluator.scala new file mode 100644 index 0000000..9fc09ff --- /dev/null +++ b/src/main/scala/com/twitter/querulous/config/QueryEvaluator.scala @@ -0,0 +1,33 @@ +package com.twitter.querulous.config + +import com.twitter.querulous._ +import com.twitter.util.Duration +import evaluator._ + +trait AutoDisablingQueryEvaluator { + def errorCount: Int + def interval: Duration +} + +class QueryEvaluator { + var database: Database = new Database + var query: Query = new Query + + var autoDisable: Option[AutoDisablingQueryEvaluator] = None + def autoDisable_=(a: AutoDisablingQueryEvaluator) { autoDisable = Some(a) } + + + def apply(stats: StatsCollector): QueryEvaluatorFactory = { + var factory: QueryEvaluatorFactory = + new StandardQueryEvaluatorFactory(database(stats), query(stats)) + + autoDisable.foreach { disable => + factory = new AutoDisablingQueryEvaluatorFactory( + factory, disable.errorCount, disable.interval + ) + } + factory + } + + def apply(): QueryEvaluatorFactory = apply(NullStatsCollector) +} diff --git a/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala b/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala index d2c590f..5ab6213 100644 --- a/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala +++ b/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala @@ -3,30 +3,42 @@ package com.twitter.querulous.database import java.sql.{SQLException, Connection} import org.apache.commons.dbcp.{PoolableConnectionFactory, DriverManagerConnectionFactory, PoolingDataSource} import org.apache.commons.pool.impl.{GenericObjectPool, StackKeyedObjectPoolFactory} -import com.twitter.xrayspecs.Duration +import com.twitter.util.Duration class ApachePoolingDatabaseFactory( - minOpenConnections: Int, - maxOpenConnections: Int, + val minOpenConnections: Int, + val maxOpenConnections: Int, checkConnectionHealthWhenIdleFor: Duration, maxWaitForConnectionReservation: Duration, checkConnectionHealthOnReservation: Boolean, - evictConnectionIfIdleFor: Duration) extends DatabaseFactory { + evictConnectionIfIdleFor: Duration, + defaultUrlOptions: Map[String, String]) extends DatabaseFactory { + + def this(minConns: Int, maxConns: Int, checkIdle: Duration, maxWait: Duration, checkHealth: Boolean, evictTime: Duration) = { + this(minConns, maxConns, checkIdle, maxWait, checkHealth, evictTime, Map.empty) + } def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = { - val pool = new ApachePoolingDatabase( + val finalUrlOptions = + if (urlOptions eq null) { + defaultUrlOptions + } else { + defaultUrlOptions ++ urlOptions + } + + new ApachePoolingDatabase( dbhosts, dbname, username, password, - urlOptions, + finalUrlOptions, minOpenConnections, maxOpenConnections, checkConnectionHealthWhenIdleFor, maxWaitForConnectionReservation, checkConnectionHealthOnReservation, - evictConnectionIfIdleFor) - pool + evictConnectionIfIdleFor + ) } } @@ -80,5 +92,5 @@ class ApachePoolingDatabase( def open() = poolingDataSource.getConnection() - override def toString = dbhosts.first + "_" + dbname + override def toString = dbhosts.head + "_" + dbname } diff --git a/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala b/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala index ed1c29f..1f14388 100644 --- a/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala +++ b/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala @@ -1,15 +1,16 @@ package com.twitter.querulous.database -import com.twitter.xrayspecs.Duration -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.querulous.AutoDisabler +import com.twitter.util.Duration +import com.twitter.util.TimeConversions._ import java.sql.{Connection, SQLException, SQLIntegrityConstraintViolationException} -class AutoDisablingDatabaseFactory(databaseFactory: DatabaseFactory, disableErrorCount: Int, disableDuration: Duration) extends DatabaseFactory { +class AutoDisablingDatabaseFactory(val databaseFactory: DatabaseFactory, val disableErrorCount: Int, val disableDuration: Duration) extends DatabaseFactory { def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = { new AutoDisablingDatabase( databaseFactory(dbhosts, dbname, username, password, urlOptions), - dbhosts.first, + dbhosts.head, disableErrorCount, disableDuration) } diff --git a/src/main/scala/com/twitter/querulous/database/Database.scala b/src/main/scala/com/twitter/querulous/database/Database.scala index d6d0c78..abcdbb8 100644 --- a/src/main/scala/com/twitter/querulous/database/Database.scala +++ b/src/main/scala/com/twitter/querulous/database/Database.scala @@ -1,46 +1,18 @@ package com.twitter.querulous.database +import com.twitter.querulous._ import java.sql.Connection -import com.twitter.xrayspecs.TimeConversions._ -import net.lag.configgy.ConfigMap - - -object DatabaseFactory { - def fromConfig(config: ConfigMap, statsCollector: Option[StatsCollector]) = { - var factory: DatabaseFactory = if (config.contains("size_min")) { - new ApachePoolingDatabaseFactory( - config("size_min").toInt, - config("size_max").toInt, - config("test_idle_msec").toLong.millis, - config("max_wait").toLong.millis, - config("test_on_borrow").toBoolean, - config("min_evictable_idle_msec").toLong.millis) - } else { - new SingleConnectionDatabaseFactory() - } - statsCollector.foreach { stats => - factory = new StatsCollectingDatabaseFactory(factory, stats) - } - config.getConfigMap("timeout").foreach { timeoutConfig => - factory = new TimingOutDatabaseFactory(factory, - timeoutConfig("pool_size").toInt, - timeoutConfig("queue_size").toInt, - timeoutConfig("open").toLong.millis, - timeoutConfig("initialize").toLong.millis, - config("size_max").toInt) - } - new MemoizingDatabaseFactory(factory) - } -} +import com.twitter.querulous.StatsCollector +import com.twitter.util.TimeConversions._ trait DatabaseFactory { def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]): Database def apply(dbhosts: List[String], dbname: String, username: String, password: String): Database = - apply(dbhosts, dbname, username, password, null) + apply(dbhosts, dbname, username, password, Map.empty) def apply(dbhosts: List[String], username: String, password: String): Database = - apply(dbhosts, null, username, password, null) + apply(dbhosts, null, username, password, Map.empty) } trait Database { @@ -57,14 +29,18 @@ trait Database { } } + val defaultUrlOptions = Map( + "useUnicode" -> "true", + "characterEncoding" -> "UTF-8", + "connectTimeout" -> "500" + ) + protected def url(dbhosts: List[String], dbname: String, urlOptions: Map[String, String]) = { val dbnameSegment = if (dbname == null) "" else ("/" + dbname) - val urlOptsSegment = if (urlOptions == null) { - "?useUnicode=true&characterEncoding=UTF-8" - } else { - "?" + urlOptions.keys.map( k => k + "=" + urlOptions(k) ).mkString("&") - } - "jdbc:mysql://" + dbhosts.mkString(",") + dbnameSegment + urlOptsSegment + val finalUrlOpts = defaultUrlOptions ++ urlOptions + val urlOptsSegment = finalUrlOpts.map(Function.tupled((k, v) => k+"="+v )).mkString("&") + + "jdbc:mysql://" + dbhosts.mkString(",") + dbnameSegment + "?" + urlOptsSegment } } diff --git a/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala b/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala index ec0ffd7..e13bf5d 100644 --- a/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala +++ b/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala @@ -2,12 +2,12 @@ package com.twitter.querulous.database import scala.collection.mutable -class MemoizingDatabaseFactory(databaseFactory: DatabaseFactory) extends DatabaseFactory { +class MemoizingDatabaseFactory(val databaseFactory: DatabaseFactory) extends DatabaseFactory { private val databases = new mutable.HashMap[String, Database] with mutable.SynchronizedMap[String, Database] def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = synchronized { databases.getOrElseUpdate( - dbhosts.first + "/" + dbname, + dbhosts.head + "/" + dbname, databaseFactory(dbhosts, dbname, username, password, urlOptions)) } diff --git a/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala b/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala index d582963..44359da 100644 --- a/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala +++ b/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala @@ -4,9 +4,18 @@ import org.apache.commons.dbcp.DriverManagerConnectionFactory import java.sql.{SQLException, Connection} -class SingleConnectionDatabaseFactory extends DatabaseFactory { +class SingleConnectionDatabaseFactory(defaultUrlOptions: Map[String, String]) extends DatabaseFactory { + def this() = this(Map.empty) + def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = { - new SingleConnectionDatabase(dbhosts, dbname, username, password, urlOptions) + val finalUrlOptions = + if (urlOptions eq null) { + defaultUrlOptions + } else { + defaultUrlOptions ++ urlOptions + } + + new SingleConnectionDatabase(dbhosts, dbname, username, password, finalUrlOptions) } } @@ -24,5 +33,5 @@ class SingleConnectionDatabase(dbhosts: List[String], dbname: String, username: } def open() = connectionFactory.createConnection() - override def toString = dbhosts.first + "_" + dbname + override def toString = dbhosts.head + "_" + dbname } diff --git a/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala b/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala index 8d77d31..d8af7ef 100644 --- a/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala +++ b/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala @@ -1,5 +1,6 @@ package com.twitter.querulous.database +import com.twitter.querulous.StatsCollector import java.sql.Connection class StatsCollectingDatabaseFactory( @@ -15,14 +16,26 @@ class StatsCollectingDatabase(database: Database, stats: StatsCollector) extends Database { override def open(): Connection = { - stats.time("database-open-timing") { - database.open() + stats.time("db-open-timing") { + try { + database.open() + } catch { + case e: SqlDatabaseTimeoutException => + stats.incr("db-open-timeout-count", 1) + throw e + } } } override def close(connection: Connection) = { - stats.time("database-close-timing") { - database.close(connection) + stats.time("db-close-timing") { + try { + database.close(connection) + } catch { + case e: SqlDatabaseTimeoutException => + stats.incr("db-close-timeout-count", 1) + throw e + } } } } diff --git a/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala b/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala index d98b582..a4b9751 100644 --- a/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala +++ b/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala @@ -1,28 +1,35 @@ package com.twitter.querulous.database +import com.twitter.querulous.{FutureTimeout, TimeoutException} import java.sql.{Connection, SQLException} import java.util.concurrent.{TimeoutException => JTimeoutException, _} -import com.twitter.xrayspecs.Duration -import net.lag.logging.Logger +import com.twitter.util.Duration class SqlDatabaseTimeoutException(msg: String, val timeout: Duration) extends SQLException(msg) -class TimingOutDatabaseFactory(databaseFactory: DatabaseFactory, poolSize: Int, queueSize: Int, openTimeout: Duration, initialTimeout: Duration, maxConnections: Int) extends DatabaseFactory { - def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = { +class TimingOutDatabaseFactory( + val databaseFactory: DatabaseFactory, + val poolSize: Int, + val queueSize: Int, + val openTimeout: Duration, + val maxConnections: Int) +extends DatabaseFactory { + + private def newTimeoutPool() = new FutureTimeout(poolSize, queueSize) + + def apply(dbhosts: List[String], dbname: String, username: String, password: String, + urlOptions: Map[String, String]) = { val dbLabel = if (dbname != null) dbname else "(null)" - new TimingOutDatabase(databaseFactory(dbhosts, dbname, username, password, urlOptions), dbhosts, dbLabel, poolSize, queueSize, openTimeout, initialTimeout, maxConnections) + new TimingOutDatabase(databaseFactory(dbhosts, dbname, username, password, urlOptions), + dbhosts, dbLabel, newTimeoutPool(), openTimeout, maxConnections) } } -class TimingOutDatabase(database: Database, dbhosts: List[String], dbname: String, poolSize: Int, queueSize: Int, openTimeout: Duration, initialTimeout: Duration, maxConnections: Int) extends Database { - private val timeout = new FutureTimeout(poolSize, queueSize) - private val log = Logger.get(getClass.getName) - - // FIXME not working yet. - //greedilyInstantiateConnections() - +class TimingOutDatabase(database: Database, dbhosts: List[String], dbname: String, + timeout: FutureTimeout, openTimeout: Duration, + maxConnections: Int) extends Database { private def getConnection(wait: Duration) = { try { timeout(wait) { @@ -36,13 +43,6 @@ class TimingOutDatabase(database: Database, dbhosts: List[String], dbname: Strin } } - private def greedilyInstantiateConnections() = { - log.info("Connecting to %s:%s", dbhosts.mkString(","), dbname) - (0 until maxConnections).force.map { i => - getConnection(initialTimeout) - }.map(_.close) - } - override def open() = getConnection(openTimeout) def close(connection: Connection) { database.close(connection) } diff --git a/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala b/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala index f4f6170..0230b20 100644 --- a/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala +++ b/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala @@ -3,8 +3,10 @@ package com.twitter.querulous.evaluator import java.sql.ResultSet import java.sql.{SQLException, SQLIntegrityConstraintViolationException} import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException -import com.twitter.xrayspecs.{Time, Duration} -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.querulous.AutoDisabler +import com.twitter.util.{Time, Duration} +import com.twitter.util.TimeConversions._ + class AutoDisablingQueryEvaluatorFactory( queryEvaluatorFactory: QueryEvaluatorFactory, @@ -28,9 +30,6 @@ class AutoDisablingQueryEvaluator ( protected val disableErrorCount: Int, protected val disableDuration: Duration) extends QueryEvaluatorProxy(queryEvaluator) with AutoDisabler { - private var disabledUntil: Time = Time.never - private var consecutiveErrors = 0 - override protected def delegate[A](f: => A) = { throwIfDisabled() try { diff --git a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala index af68ed4..1de3531 100644 --- a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala +++ b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala @@ -1,29 +1,12 @@ package com.twitter.querulous.evaluator +import com.twitter.querulous._ import java.sql.ResultSet -import com.twitter.xrayspecs.TimeConversions._ -import net.lag.configgy.ConfigMap -import database._ -import query._ - - -object QueryEvaluatorFactory { - def fromConfig(config: ConfigMap, databaseFactory: DatabaseFactory, queryFactory: QueryFactory): QueryEvaluatorFactory = { - var factory: QueryEvaluatorFactory = new StandardQueryEvaluatorFactory(databaseFactory, queryFactory) - config.getConfigMap("disable").foreach { disableConfig => - factory = new AutoDisablingQueryEvaluatorFactory(factory, - disableConfig("error_count").toInt, - disableConfig("seconds").toInt.seconds) - } - factory - } +import com.twitter.util.TimeConversions._ +import com.twitter.querulous.StatsCollector +import com.twitter.querulous.database._ +import com.twitter.querulous.query._ - def fromConfig(config: ConfigMap, statsCollector: Option[StatsCollector]): QueryEvaluatorFactory = { - fromConfig(config, - DatabaseFactory.fromConfig(config.configMap("connection_pool"), statsCollector), - QueryFactory.fromConfig(config, statsCollector)) - } -} object QueryEvaluator extends QueryEvaluatorFactory { private def createEvaluatorFactory() = { @@ -45,39 +28,56 @@ trait QueryEvaluatorFactory { } def apply(dbhosts: List[String], dbname: String, username: String, password: String): QueryEvaluator = { - apply(dbhosts, dbname, username, password, null) + apply(dbhosts, dbname, username, password, Map[String,String]()) } def apply(dbhost: String, dbname: String, username: String, password: String): QueryEvaluator = { - apply(List(dbhost), dbname, username, password, null) + apply(List(dbhost), dbname, username, password, Map[String,String]()) } def apply(dbhost: String, username: String, password: String): QueryEvaluator = { - apply(List(dbhost), null, username, password, null) + apply(List(dbhost), null, username, password, Map[String,String]()) } def apply(dbhosts: List[String], username: String, password: String): QueryEvaluator = { - apply(dbhosts, null, username, password, null) + apply(dbhosts, null, username, password, Map[String,String]()) } - def apply(config: ConfigMap): QueryEvaluator = { - apply( - config.getList("hostname").toList, - config.getString("database").getOrElse(null), - config("username"), - config.getString("password").getOrElse(null), - // this is so lame, why do I have to cast this back? - config.getConfigMap("url_options").map(_.asMap.asInstanceOf[Map[String, String]]).getOrElse(null) - ) + def apply(connection: config.Connection): QueryEvaluator = { + apply(connection.hostnames.toList, connection.database, connection.username, connection.password, connection.urlOptions) } } +class ParamsApplier(query: Query) { + def apply(params: Any*) = query.addParams(params) +} + trait QueryEvaluator { - def select[A](query: String, params: Any*)(f: ResultSet => A): Seq[A] - def selectOne[A](query: String, params: Any*)(f: ResultSet => A): Option[A] - def count(query: String, params: Any*): Int - def execute(query: String, params: Any*): Int + def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A): Seq[A] + def select[A](query: String, params: Any*)(f: ResultSet => A): Seq[A] = + select(QueryClass.Select, query, params: _*)(f) + + def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A): Option[A] + def selectOne[A](query: String, params: Any*)(f: ResultSet => A): Option[A] = + selectOne(QueryClass.Select, query, params: _*)(f) + + def count(queryClass: QueryClass, query: String, params: Any*): Int + def count(query: String, params: Any*): Int = + count(QueryClass.Select, query, params: _*) + + def execute(queryClass: QueryClass, query: String, params: Any*): Int + def execute(query: String, params: Any*): Int = + execute(QueryClass.Execute, query, params: _*) + + def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit): Int + def executeBatch(query: String)(f: ParamsApplier => Unit): Int = + executeBatch(QueryClass.Execute, query)(f) + def nextId(tableName: String): Long - def insert(query: String, params: Any*): Long + + def insert(queryClass: QueryClass, query: String, params: Any*): Long + def insert(query: String, params: Any*): Long = + insert(QueryClass.Execute, query, params: _*) + def transaction[T](f: Transaction => T): T } diff --git a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala index b9fec3b..77177bc 100644 --- a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala +++ b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala @@ -3,31 +3,37 @@ package com.twitter.querulous.evaluator import java.sql.ResultSet import java.sql.{SQLException, SQLIntegrityConstraintViolationException} import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException -import com.twitter.xrayspecs.{Time, Duration} -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.{Time, Duration} +import com.twitter.util.TimeConversions._ +import com.twitter.querulous.query.{QueryClass, Query} + abstract class QueryEvaluatorProxy(queryEvaluator: QueryEvaluator) extends QueryEvaluator { - def select[A](query: String, params: Any*)(f: ResultSet => A) = { - delegate(queryEvaluator.select(query, params: _*)(f)) + def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = { + delegate(queryEvaluator.select(queryClass, query, params: _*)(f)) + } + + def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = { + delegate(queryEvaluator.selectOne(queryClass, query, params: _*)(f)) } - def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = { - delegate(queryEvaluator.selectOne(query, params: _*)(f)) + def execute(queryClass: QueryClass, query: String, params: Any*) = { + delegate(queryEvaluator.execute(queryClass, query, params: _*)) } - def execute(query: String, params: Any*) = { - delegate(queryEvaluator.execute(query, params: _*)) + def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit) = { + delegate(queryEvaluator.executeBatch(queryClass, query)(f)) } - def count(query: String, params: Any*) = { - delegate(queryEvaluator.count(query, params: _*)) + def count(queryClass: QueryClass, query: String, params: Any*) = { + delegate(queryEvaluator.count(queryClass, query, params: _*)) } def nextId(tableName: String) = { delegate(queryEvaluator.nextId(tableName)) } - def insert(query: String, params: Any*) = { + def insert(queryClass: QueryClass, query: String, params: Any*) = { delegate(queryEvaluator.insert(query, params: _*)) } diff --git a/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala b/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala index f0c7a7a..78e1e6b 100644 --- a/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala +++ b/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala @@ -4,7 +4,7 @@ import java.sql.ResultSet import org.apache.commons.dbcp.{DriverManagerConnectionFactory, PoolableConnectionFactory, PoolingDataSource} import org.apache.commons.pool.impl.GenericObjectPool import com.twitter.querulous.database.{Database, DatabaseFactory} -import com.twitter.querulous.query.QueryFactory +import com.twitter.querulous.query.{Query, QueryClass, QueryFactory} class StandardQueryEvaluatorFactory( databaseFactory: DatabaseFactory, @@ -19,12 +19,31 @@ class StandardQueryEvaluatorFactory( class StandardQueryEvaluator(protected val database: Database, queryFactory: QueryFactory) extends QueryEvaluator { - def select[A](query: String, params: Any*)(f: ResultSet => A) = withTransaction(_.select(query, params: _*)(f)) - def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = withTransaction(_.selectOne(query, params: _*)(f)) - def count(query: String, params: Any*) = withTransaction(_.count(query, params: _*)) - def execute(query: String, params: Any*) = withTransaction(_.execute(query, params: _*)) + def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = { + withTransaction(_.select(queryClass, query, params: _*)(f)) + } + + def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = { + withTransaction(_.selectOne(queryClass, query, params: _*)(f)) + } + + def count(queryClass: QueryClass, query: String, params: Any*) = { + withTransaction(_.count(queryClass, query, params: _*)) + } + + def execute(queryClass: QueryClass, query: String, params: Any*) = { + withTransaction(_.execute(queryClass, query, params: _*)) + } + + def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit) = { + withTransaction(_.executeBatch(queryClass, query)(f)) + } + def nextId(tableName: String) = withTransaction(_.nextId(tableName)) - def insert(query: String, params: Any*) = withTransaction(_.insert(query, params: _*)) + + def insert(queryClass: QueryClass, query: String, params: Any*) = { + withTransaction(_.insert(queryClass, query, params: _*)) + } def transaction[T](f: Transaction => T) = { withTransaction { transaction => @@ -35,7 +54,9 @@ class StandardQueryEvaluator(protected val database: Database, queryFactory: Que rv } catch { case e: Throwable => - transaction.rollback() + try { + transaction.rollback() + } catch { case _ => () } throw e } } diff --git a/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala b/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala index e438412..6f0fe3c 100644 --- a/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala +++ b/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala @@ -1,24 +1,29 @@ package com.twitter.querulous.evaluator import java.sql.{ResultSet, SQLException, SQLIntegrityConstraintViolationException, Connection} -import com.twitter.querulous.query.QueryFactory +import com.twitter.querulous.query.{QueryClass, QueryFactory, Query} class Transaction(queryFactory: QueryFactory, connection: Connection) extends QueryEvaluator { - def select[A](query: String, params: Any*)(f: ResultSet => A) = { - queryFactory(connection, query, params: _*).select(f) + def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = { + queryFactory(connection, queryClass, query, params: _*).select(f) } - def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = { - val results = select(query, params: _*)(f) - if (results.isEmpty) None else Some(results.first) + def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = { + select(queryClass, query, params: _*)(f).headOption } - def count(query: String, params: Any*) = { - selectOne(query, params: _*)(_.getInt("count(*)")) getOrElse 0 + def count(queryClass: QueryClass, query: String, params: Any*) = { + selectOne(queryClass, query, params: _*)(_.getInt("count(*)")) getOrElse 0 } - def execute(query: String, params: Any*) = { - queryFactory(connection, query, params: _*).execute() + def execute(queryClass: QueryClass, query: String, params: Any*) = { + queryFactory(connection, queryClass, query, params: _*).execute() + } + + def executeBatch(queryClass: QueryClass, queryString: String)(f: ParamsApplier => Unit) = { + val query = queryFactory(connection, queryClass, queryString) + f(new ParamsApplier(query)) + query.execute } def nextId(tableName: String) = { @@ -26,8 +31,8 @@ class Transaction(queryFactory: QueryFactory, connection: Connection) extends Qu selectOne("SELECT LAST_INSERT_ID()") { _.getLong("LAST_INSERT_ID()") } getOrElse 0L } - def insert(query: String, params: Any*): Long = { - execute(query, params: _*) + def insert(queryClass: QueryClass, query: String, params: Any*): Long = { + execute(queryClass, query, params: _*) selectOne("SELECT LAST_INSERT_ID()") { _.getLong("LAST_INSERT_ID()") } getOrElse { throw new SQLIntegrityConstraintViolationException } diff --git a/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala b/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala index 0c48681..cad1323 100644 --- a/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala +++ b/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala @@ -3,8 +3,8 @@ package com.twitter.querulous.query import java.sql.{Timestamp, Connection} class DebuggingQueryFactory(queryFactory: QueryFactory, log: String => Unit) extends QueryFactory { - def apply(connection: Connection, query: String, params: Any*) = { - new DebuggingQuery(queryFactory(connection, query, params: _*), log, query, params) + def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = { + new DebuggingQuery(queryFactory(connection, queryClass, query, params: _*), log, query, params) } } diff --git a/src/main/scala/com/twitter/querulous/query/Query.scala b/src/main/scala/com/twitter/querulous/query/Query.scala index 22b50b4..873f01a 100644 --- a/src/main/scala/com/twitter/querulous/query/Query.scala +++ b/src/main/scala/com/twitter/querulous/query/Query.scala @@ -1,67 +1,19 @@ package com.twitter.querulous.query +import com.twitter.querulous._ import java.sql.{ResultSet, Connection} import scala.collection.mutable -import com.twitter.xrayspecs.Duration -import com.twitter.xrayspecs.TimeConversions._ -import net.lag.configgy.ConfigMap -import net.lag.logging.Logger - +import com.twitter.querulous.StatsCollector +import com.twitter.util.Duration +import com.twitter.util.TimeConversions._ trait QueryFactory { - def apply(connection: Connection, queryString: String, params: Any*): Query + def apply(connection: Connection, queryClass: QueryClass, queryString: String, params: Any*): Query } trait Query { def select[A](f: ResultSet => A): Seq[A] def execute(): Int + def addParams(params: Any*) def cancel() } - -object QueryFactory { - private def convertConfigMap(queryMap: ConfigMap) = { - val queryInfo = new mutable.HashMap[String, (String, Duration)] - for (key <- queryMap.keys) { - val pair = queryMap.getList(key) - val query = pair(0) - val timeout = pair(1).toLong.millis - queryInfo += (query -> (key, timeout)) - } - queryInfo - } - - /* - query_timeout_default = 3000 - queries { - select_source_id_for_update = ["SELECT * FROM ? WHERE source_id = ? FOR UPDATE", 3000] - } - retries = 3 - debug = false - */ - def fromConfig(config: ConfigMap, statsCollector: Option[StatsCollector]): QueryFactory = { - var queryFactory: QueryFactory = new SqlQueryFactory - config.getConfigMap("queries") match { - case Some(queryMap) => - val queryInfo = convertConfigMap(queryMap) - val timeout = config("query_timeout_default").toLong.millis - queryFactory = new TimingOutStatsCollectingQueryFactory(queryFactory, queryInfo, timeout, - statsCollector.get) - case None => - config.getInt("query_timeout_default").foreach { timeout => - queryFactory = new TimingOutQueryFactory(queryFactory, timeout.millis) - } - statsCollector.foreach { stats => - queryFactory = new StatsCollectingQueryFactory(queryFactory, stats) - } - } - - config.getInt("retries").foreach { retries => - queryFactory = new RetryingQueryFactory(queryFactory, retries) - } - if (config.getBool("debug", false)) { - val log = Logger.get(getClass.getName) - queryFactory = new DebuggingQueryFactory(queryFactory, { s => log.debug(s) }) - } - queryFactory - } -} \ No newline at end of file diff --git a/src/main/scala/com/twitter/querulous/query/QueryClass.scala b/src/main/scala/com/twitter/querulous/query/QueryClass.scala new file mode 100644 index 0000000..39618e4 --- /dev/null +++ b/src/main/scala/com/twitter/querulous/query/QueryClass.scala @@ -0,0 +1,19 @@ +package com.twitter.querulous.query + +import scala.collection.mutable + +class QueryClass(val name: String) + +object QueryClass { + val classes = mutable.Map[String, QueryClass]() + + def apply(name: String) = { + classes(name) = new QueryClass(name) + lookup(name) + } + + def lookup(name: String) = classes(name) + + val Select = QueryClass("select") + val Execute = QueryClass("execute") +} diff --git a/src/main/scala/com/twitter/querulous/query/QueryProxy.scala b/src/main/scala/com/twitter/querulous/query/QueryProxy.scala index 145a8a7..697daaf 100644 --- a/src/main/scala/com/twitter/querulous/query/QueryProxy.scala +++ b/src/main/scala/com/twitter/querulous/query/QueryProxy.scala @@ -9,5 +9,7 @@ abstract class QueryProxy(query: Query) extends Query { def cancel() = query.cancel() + def addParams(params: Any*) = query.addParams(params) + protected def delegate[A](f: => A) = f } diff --git a/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala b/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala index 6189cf7..a413ff9 100644 --- a/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala +++ b/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala @@ -1,12 +1,12 @@ package com.twitter.querulous.query import java.sql.{SQLException, Connection} -import com.twitter.xrayspecs.Duration +import com.twitter.util.Duration class RetryingQueryFactory(queryFactory: QueryFactory, retries: Int) extends QueryFactory { - def apply(connection: Connection, query: String, params: Any*) = { - new RetryingQuery(queryFactory(connection, query, params: _*), retries) + def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = { + new RetryingQuery(queryFactory(connection, queryClass, query, params: _*), retries) } } diff --git a/src/main/scala/com/twitter/querulous/query/SqlQuery.scala b/src/main/scala/com/twitter/querulous/query/SqlQuery.scala index 60577ca..4e655d5 100644 --- a/src/main/scala/com/twitter/querulous/query/SqlQuery.scala +++ b/src/main/scala/com/twitter/querulous/query/SqlQuery.scala @@ -6,13 +6,13 @@ import java.util.regex.Pattern import scala.collection.mutable class SqlQueryFactory extends QueryFactory { - def apply(connection: Connection, query: String, params: Any*) = { + def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = { new SqlQuery(connection, query, params: _*) } } -class TooFewQueryParametersException extends Exception -class TooManyQueryParametersException extends Exception +class TooFewQueryParametersException(t: Throwable) extends Exception(t) +class TooManyQueryParametersException(t: Throwable) extends Exception(t) sealed abstract case class NullValue(typeVal: Int) object NullValues { @@ -43,7 +43,13 @@ object NullValues { class SqlQuery(connection: Connection, query: String, params: Any*) extends Query { - val statement = buildStatement(connection, query, params: _*) + def this(connection: Connection, query: String) = { + this(connection, query, Nil) + } + + var paramsInitialized = false + var statement = buildStatement(connection, query, params: _*) + var batchMode = false def select[A](f: ResultSet => A): Seq[A] = { withStatement { @@ -61,9 +67,22 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer } } + def addParams(params: Any*) = { + if(paramsInitialized && !batchMode) { + statement.addBatch() + } + setBindVariable(statement, 1, params) + statement.addBatch() + batchMode = true + } + def execute() = { withStatement { - statement.executeUpdate() + if(batchMode) { + statement.executeBatch().foldLeft(0)(_+_) + } else { + statement.executeUpdate() + } } } @@ -94,26 +113,37 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer statement } - private def expandArrayParams(query: String, params: Any*) = { + private def expandArrayParams(query: String, params: Any*): String = { + if(params.isEmpty){ + return query + } val p = Pattern.compile("\\?") val m = p.matcher(query) val result = new StringBuffer var i = 0 + + def marks(param: Any): String = param match { + case t2: (_,_) => "(?,?)" + case t3: (_,_,_) => "(?,?,?)" + case t4: (_,_,_,_) => "(?,?,?,?)" + case a: Array[Byte] => "?" + case s: Seq[_] => s.map(marks(_)).mkString(",") + case _ => "?" + } + while (m.find) { try { - val questionMarks = params(i) match { - case a: Array[Byte] => "?" - case s: Seq[_] => s.map { _ => "?" }.mkString(",") - case _ => "?" - } - m.appendReplacement(result, questionMarks) + m.appendReplacement(result, marks(params(i))) } catch { - case e: ArrayIndexOutOfBoundsException => throw new TooFewQueryParametersException - case e: NoSuchElementException => throw new TooFewQueryParametersException + case e: ArrayIndexOutOfBoundsException => + throw new TooFewQueryParametersException(e) + case e: NoSuchElementException => + throw new TooFewQueryParametersException(e) } i += 1 } m.appendTail(result) + paramsInitialized = true result.toString } @@ -122,6 +152,12 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer try { param match { + case (a, b) => + index = setBindVariable(statement, index, List(a, b)) - 1 + case (a, b, c) => + index = setBindVariable(statement, index, List(a, b, c)) - 1 + case (a, b, c, d) => + index = setBindVariable(statement, index, List(a, b, c, d)) - 1 case s: String => statement.setString(index, s) case l: Long => @@ -146,7 +182,8 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer } index + 1 } catch { - case e: SQLException => throw new TooManyQueryParametersException + case e: SQLException => + throw new TooManyQueryParametersException(e) } } } diff --git a/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala b/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala index bbd74ea..9f89e64 100644 --- a/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala +++ b/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala @@ -1,27 +1,30 @@ package com.twitter.querulous.query +import com.twitter.querulous.StatsCollector import java.sql.{ResultSet, Connection} class StatsCollectingQueryFactory(queryFactory: QueryFactory, stats: StatsCollector) extends QueryFactory { - def apply(connection: Connection, query: String, params: Any*) = { - new StatsCollectingQuery(queryFactory(connection, query, params: _*), stats) + def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = { + new StatsCollectingQuery(queryFactory(connection, queryClass, query, params: _*), queryClass, stats) } } -class StatsCollectingQuery(query: Query, stats: StatsCollector) extends QueryProxy(query) { - override def select[A](f: ResultSet => A) = { - stats.incr("db-select-count", 1) - delegate(query.select(f)) - } - - override def execute() = { - stats.incr("db-execute-count", 1) - delegate(query.execute()) - } - +class StatsCollectingQuery(query: Query, queryClass: QueryClass, stats: StatsCollector) extends QueryProxy(query) { override def delegate[A](f: => A) = { - stats.time("db-timing")(f) + stats.incr("db-" + queryClass.name + "-count", 1) + stats.time("db-" + queryClass.name + "-timing") { + stats.time("db-timing") { + try { + f + } catch { + case e: SqlQueryTimeoutException => + stats.incr("db-query-timeout-count", 1) + stats.incr("db-query-" + queryClass.name + "-timeout-count", 1) + throw e + } + } + } } } diff --git a/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala b/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala index 02c88e8..ecb7631 100644 --- a/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala +++ b/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala @@ -1,8 +1,10 @@ package com.twitter.querulous.query import java.sql.{SQLException, Connection} -import com.twitter.xrayspecs.Duration -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Duration +import com.twitter.util.TimeConversions._ +import scala.collection.Map +import com.twitter.querulous.{Timeout, TimeoutException} class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Query timeout: " + timeout.inMillis + " msec") @@ -14,29 +16,26 @@ class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Quer *

Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working * and executing promptly for the JDBC driver in use. */ -class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration, cancelTimeout: Duration) extends QueryFactory { - def this(queryFactory: QueryFactory, timeout: Duration) = this(queryFactory, timeout, 0.millis) +class TimingOutQueryFactory(queryFactory: QueryFactory, val timeout: Duration, val cancelOnTimeout: Boolean) + extends QueryFactory { + + def this(queryFactory: QueryFactory, timeout: Duration) = this(queryFactory, timeout, false) - def apply(connection: Connection, query: String, params: Any*) = { - new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeout, cancelTimeout) + def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = { + new TimingOutQuery(queryFactory(connection, queryClass, query, params: _*), connection, timeout, cancelOnTimeout) } } /** - * A {@code QueryFactory} that creates {@link Query}s that execute subject to the {@code timeouts} - * specified for individual queries. An attempt to {@link Query#cancel} a query is made if the - * timeout expires. - * - *

Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working - * and executing promptly for the JDBC driver in use. + * A `QueryFactory` that creates `Query`s that execute subject to the timeouts + * specified for individual query classes. */ -class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[String, Duration], cancelTimeout: Duration) +class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, val timeouts: Map[QueryClass, (Duration, Boolean)]) extends QueryFactory { - def this(queryFactory: QueryFactory, timeouts: Map[String, Duration]) = this(queryFactory, timeouts, 0.millis) - - def apply(connection: Connection, query: String, params: Any*) = { - new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeouts(query), cancelTimeout) + def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = { + val (timeout, cancelOnTimeout) = timeouts(queryClass) + new TimingOutQuery(queryFactory(connection, queryClass, query, params: _*), connection, timeout, cancelOnTimeout) } } @@ -51,21 +50,21 @@ private object QueryCancellation { *

Note that the query timing out promptly is based upon {@link java.sql.Statement#cancel} * working and executing promptly for the JDBC driver in use. */ -class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, cancelTimeout: Duration) +class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, cancelOnTimeout: Boolean) extends QueryProxy(query) with ConnectionDestroying { + def this(query: Query, connection: Connection, timeout: Duration) = this(query, connection, timeout, false) + import QueryCancellation._ override def delegate[A](f: => A) = { try { - Timeout(timeout) { - f - } { - cancel() + Timeout(cancelTimer, timeout)(f) { + if (cancelOnTimeout) cancel() + destroyConnection(connection) } } catch { - case e: TimeoutException => - throw new SqlQueryTimeoutException(timeout) + case e: TimeoutException => throw new SqlQueryTimeoutException(timeout) } } @@ -73,14 +72,10 @@ class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, ca val cancelThread = new Thread("query cancellation") { override def run() { try { - Timeout(cancelTimer, cancelTimeout) { - // start by trying the nice way - query.cancel() - } { - // if the cancel times out, destroy the underlying connection - destroyConnection(connection) - } - } catch { case e: TimeoutException => } + // This cancel may block, as it has to connect to the database. + // If the default socket connection timeout has been removed, this thread will run away. + query.cancel() + } catch { case e => () } } } cancelThread.start() diff --git a/src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala b/src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala deleted file mode 100644 index c1ec893..0000000 --- a/src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala +++ /dev/null @@ -1,67 +0,0 @@ -package com.twitter.querulous.query - -import java.sql.{Connection, ResultSet} -import scala.collection.Map -import scala.util.matching.Regex -import scala.collection.Map -import com.twitter.xrayspecs.Duration -import com.twitter.xrayspecs.TimeConversions._ -import net.lag.extensions._ - - -object TimingOutStatsCollectingQueryFactory { - val TABLE_NAME = """(FROM|UPDATE|INSERT INTO|LIMIT)\s+[\w-]+""".r - val DDL_QUERY = """^\s*((CREATE|DROP|ALTER)\s+(TABLE|DATABASE)|DESCRIBE)\s+""".r - - def simplifiedQuery(query: String) = { - if (DDL_QUERY.findFirstMatchIn(query).isDefined) { - "default" - } else { - query.regexSub(TABLE_NAME) { m => m.group(1) + "?" } - } - } -} - -class TimingOutStatsCollectingQueryFactory(queryFactory: QueryFactory, - queryInfo: Map[String, (String, Duration)], - defaultTimeout: Duration, cancelTimeout: Duration, stats: StatsCollector) - extends QueryFactory { - - def this(queryFactory: QueryFactory, queryInfo: Map[String, (String, Duration)], defaultTimeout: Duration, stats: StatsCollector) = - this(queryFactory, queryInfo, defaultTimeout, 0.millis, stats) - - def apply(connection: Connection, query: String, params: Any*) = { - val simplifiedQueryString = TimingOutStatsCollectingQueryFactory.simplifiedQuery(query) - val (name, timeout) = queryInfo.getOrElse(simplifiedQueryString, ("default", defaultTimeout)) - - new TimingOutStatsCollectingQuery( - new TimingOutQuery( - queryFactory(connection, query, params: _*), - connection, - timeout, - cancelTimeout), - name, - stats) - } -} - -class TimingOutStatsCollectingQuery(query: Query, queryName: String, stats: StatsCollector) extends QueryProxy(query) { - override def select[A](f: ResultSet => A) = { - stats.incr("db-select-count", 1) - stats.time("db-select-timing")(delegate(query.select(f))) - } - - override def execute() = { - stats.incr("db-execute-count", 1) - stats.time("db-execute-timing")(delegate(query.execute())) - } - - override def delegate[A](f: => A) = { - stats.incr("db-query-count-" + queryName, 1) - stats.time("db-timing") { - stats.time("x-db-query-timing-" + queryName) { - f - } - } - } -} diff --git a/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala b/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala index 8f5743e..635ad73 100644 --- a/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala +++ b/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala @@ -1,17 +1,20 @@ package com.twitter.querulous.test import java.sql.Connection -import com.twitter.xrayspecs.{Duration, Time} -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.{Duration, Time} +import com.twitter.util.TimeConversions._ import com.twitter.querulous.database.Database -class FakeDatabase(connection: Connection, latency: Duration) extends Database { +class FakeDatabase(connection: Connection, before: Option[String => Unit]) extends Database { + def this(connection: Connection) = this(connection, None) + def this(connection: Connection, before: String => Unit) = this(connection, Some(before)) + def open(): Connection = { - Time.advance(latency) + before.foreach { _("open") } connection } def close(connection: Connection) { - Time.advance(latency) + before.foreach { _("close") } } } diff --git a/src/main/scala/com/twitter/querulous/test/FakeQuery.scala b/src/main/scala/com/twitter/querulous/test/FakeQuery.scala index a243ceb..b746244 100644 --- a/src/main/scala/com/twitter/querulous/test/FakeQuery.scala +++ b/src/main/scala/com/twitter/querulous/test/FakeQuery.scala @@ -12,4 +12,6 @@ class FakeQuery(resultSets: Seq[ResultSet]) extends Query { } override def execute() = 0 + + def addParams(params: Any*) = {} } diff --git a/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala b/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala index f938080..17d60f2 100644 --- a/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala +++ b/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala @@ -1,15 +1,18 @@ package com.twitter.querulous.test import java.sql.ResultSet -import com.twitter.xrayspecs.Duration -import com.twitter.querulous.evaluator.{Transaction, QueryEvaluator} +import com.twitter.util.Duration +import com.twitter.querulous.evaluator.{Transaction, QueryEvaluator, ParamsApplier} +import com.twitter.querulous.query.{QueryClass, Query} + class FakeQueryEvaluator[A](trans: Transaction, resultSets: Seq[ResultSet]) extends QueryEvaluator { - def select[A](query: String, params: Any*)(f: ResultSet => A) = resultSets.map(f) - def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = None - def count(query: String, params: Any*) = 0 - def execute(query: String, params: Any*) = 0 + def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = resultSets.map(f) + def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = None + def count(queryClass: QueryClass, query: String, params: Any*) = 0 + def execute(queryClass: QueryClass, query: String, params: Any*) = 0 + def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit) = 0 def nextId(tableName: String) = 0 - def insert(query: String, params: Any*) = 0 + def insert(queryClass: QueryClass, query: String, params: Any*) = 0 def transaction[T](f: Transaction => T) = f(trans) } diff --git a/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala b/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala index c2777ba..255ae23 100644 --- a/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala +++ b/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala @@ -1,22 +1,24 @@ package com.twitter.querulous.test import scala.collection.mutable.Map -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ +import com.twitter.querulous.StatsCollector + class FakeStatsCollector extends StatsCollector { val counts = Map[String, Int]() val times = Map[String, Long]() def incr(name: String, count: Int) = { - counts + (name -> (count+counts.getOrElseUpdate(name, 0))) + counts += (name -> (count+counts.getOrElseUpdate(name, 0))) } def time[A](name: String)(f: => A): A = { val start = Time.now val rv = f val end = Time.now - times + (name -> ((end-start).inMillis + times.getOrElseUpdate(name, 0L))) + times += (name -> ((end-start).inMillis + times.getOrElseUpdate(name, 0L))) rv } } diff --git a/src/test/scala/com/twitter/querulous/TestEvaluator.scala b/src/test/scala/com/twitter/querulous/TestEvaluator.scala index 068c0f9..cd44fe5 100644 --- a/src/test/scala/com/twitter/querulous/TestEvaluator.scala +++ b/src/test/scala/com/twitter/querulous/TestEvaluator.scala @@ -1,16 +1,58 @@ package com.twitter.querulous -import net.lag.configgy.Configgy import com.twitter.querulous.database.{MemoizingDatabaseFactory, SingleConnectionDatabaseFactory} import com.twitter.querulous.query.SqlQueryFactory -import com.twitter.querulous.evaluator.StandardQueryEvaluatorFactory -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.querulous.evaluator.{QueryEvaluator, StandardQueryEvaluatorFactory} +import com.twitter.util.Eval +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ +import java.io.File +import java.util.concurrent.CountDownLatch +import org.specs.Specification + +import config.Connection + + +trait ConfiguredSpecification extends Specification { + val config = try { + Eval[Connection](new File("config/test.scala")) + } catch { + case e => + e.printStackTrace() + throw e + } +} object TestEvaluator { // val testDatabaseFactory = new MemoizingDatabaseFactory() val testDatabaseFactory = new SingleConnectionDatabaseFactory val testQueryFactory = new SqlQueryFactory val testEvaluatorFactory = new StandardQueryEvaluatorFactory(testDatabaseFactory, testQueryFactory) + + private val userEnv = System.getenv("DB_USERNAME") + private val passEnv = System.getenv("DB_PASSWORD") + + def getDbLock(queryEvaluator: QueryEvaluator, lockName: String) = { + val returnLatch = new CountDownLatch(1) + val releaseLatch = new CountDownLatch(1) + + val thread = new Thread() { + override def run() { + queryEvaluator.select("SELECT GET_LOCK('" + lockName + "', 1) AS rv") { row => + returnLatch.countDown() + try { + releaseLatch.await() + } catch { + case _ => + } + } + } + } + + thread.start() + returnLatch.await() + + releaseLatch + } } diff --git a/src/test/scala/com/twitter/querulous/TestRunner.scala b/src/test/scala/com/twitter/querulous/TestRunner.scala deleted file mode 100644 index 83a18cf..0000000 --- a/src/test/scala/com/twitter/querulous/TestRunner.scala +++ /dev/null @@ -1,14 +0,0 @@ -package com.twitter.querulous - -import org.specs.runner.SpecsFileRunner -import org.specs.util.Configuration -import net.lag.configgy.Configgy - -object TestRunner extends SpecsFileRunner("src/test/scala/**/*.scala", ".*", - System.getProperty("system", ".*"), System.getProperty("example", ".*")) { - - System.setProperty("stage", "test") - - Configgy.configure(System.getProperty("basedir") + "/config/" + System.getProperty("stage", "test") + ".conf") -} - diff --git a/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala b/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala index c9f572b..b177162 100644 --- a/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala +++ b/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala @@ -1,22 +1,29 @@ package com.twitter.querulous.integration -import org.specs.Specification -import net.lag.configgy.Configgy -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ +import com.twitter.querulous.ConfiguredSpecification +import com.twitter.querulous.TestEvaluator import com.twitter.querulous.database.ApachePoolingDatabaseFactory import com.twitter.querulous.query._ import com.twitter.querulous.evaluator.{StandardQueryEvaluatorFactory, QueryEvaluator} +class QuerySpec extends ConfiguredSpecification { +// Configgy.configure("config/" + System.getProperty("stage", "test") + ".conf") -class QuerySpec extends Specification { - Configgy.configure("config/" + System.getProperty("stage", "test") + ".conf") - +// val config = Configgy.config.configMap("db") +// val username = config("username") +// val password = config("password") +// val queryEvaluator = testEvaluatorFactory("localhost", "db_test", username, password) import TestEvaluator._ - val config = Configgy.config.configMap("db") + val queryEvaluator = testEvaluatorFactory(config) "Query" should { - val queryEvaluator = testEvaluatorFactory(config) + doBefore { + queryEvaluator.execute("CREATE TABLE IF NOT EXISTS foo(bar INT, baz INT)") + queryEvaluator.execute("TRUNCATE foo") + queryEvaluator.execute("INSERT INTO foo VALUES (1,1), (3,3)") + } "with too many arguments" >> { queryEvaluator.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", 1, 2, 3) { r => 1 } must throwA[TooManyQueryParametersException] @@ -26,15 +33,22 @@ class QuerySpec extends Specification { queryEvaluator.select("SELECT 1 FROM DUAL WHERE 1 = ? OR 1 = ?", 1) { r => 1 } must throwA[TooFewQueryParametersException] } + "in batch mode" >> { + queryEvaluator.executeBatch("UPDATE foo SET bar = ? WHERE bar = ?") { withParams => + withParams("2", "1") + withParams("3", "3") + } mustEqual 2 + } + "with just the right number of arguments" >> { queryEvaluator.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", List(1, 2, 3))(_.getInt(1)).toList mustEqual List(1) } "be backwards compatible" >> { - val noOpts = testEvaluatorFactory("localhost", null, config("username"), config("password")) + val noOpts = testEvaluatorFactory(config.hostnames.toList, null, config.username, config.password) noOpts.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", List(1, 2, 3))(_.getInt(1)).toList mustEqual List(1) - val noDBNameOrOpts = testEvaluatorFactory("localhost", config("username"), config("password")) + val noDBNameOrOpts = testEvaluatorFactory(config.hostnames.toList, config.username, config.password) noDBNameOrOpts.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", List(1, 2, 3))(_.getInt(1)).toList mustEqual List(1) } } diff --git a/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala b/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala index d177f7d..dbe92bb 100644 --- a/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala +++ b/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala @@ -1,51 +1,43 @@ package com.twitter.querulous.integration import java.util.concurrent.CountDownLatch -import org.specs.Specification -import net.lag.configgy.Configgy -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ +import com.twitter.querulous.TestEvaluator import com.twitter.querulous.database.ApachePoolingDatabaseFactory import com.twitter.querulous.query.{SqlQueryFactory, TimingOutQueryFactory, SqlQueryTimeoutException} import com.twitter.querulous.evaluator.{StandardQueryEvaluatorFactory, QueryEvaluator} +import com.twitter.querulous.ConfiguredSpecification -class TimeoutSpec extends Specification { - Configgy.configure("config/test.conf") - +class TimeoutSpec extends ConfiguredSpecification { import TestEvaluator._ - val config = Configgy.config.configMap("db") - val username = config("username") - val password = config("password") val timeout = 1.second - val timingOutQueryFactory = new TimingOutQueryFactory(testQueryFactory, timeout) + val timingOutQueryFactory = new TimingOutQueryFactory(testQueryFactory, timeout, false) + val apacheDatabaseFactory = new ApachePoolingDatabaseFactory(10, 10, 1.second, 10.millis, false, 0.seconds) val timingOutQueryEvaluatorFactory = new StandardQueryEvaluatorFactory(testDatabaseFactory, timingOutQueryFactory) "Timeouts" should { doBefore { - testEvaluatorFactory("localhost", null, username, password).execute("CREATE DATABASE IF NOT EXISTS db_test") + testEvaluatorFactory(config.withoutDatabase).execute("CREATE DATABASE IF NOT EXISTS db_test") } "honor timeouts" in { - val queryEvaluator1 = testEvaluatorFactory(List("localhost"), "db_test", username, password) - val latch = new CountDownLatch(1) + val queryEvaluator1 = testEvaluatorFactory(config) + val dbLock = getDbLock(queryEvaluator1, "padlock") + val thread = new Thread() { override def run() { - queryEvaluator1.select("SELECT GET_LOCK('padlock', 1) AS rv") { row => - latch.countDown() - try { - Thread.sleep(60.seconds.inMillis) - } catch { - case _ => - } - } + try { + Thread.sleep(60.seconds.inMillis) + } catch { case _ => () } + dbLock.countDown() } } thread.start() - latch.await() - val queryEvaluator2 = timingOutQueryEvaluatorFactory(List("localhost"), "db_test", username, password) + val queryEvaluator2 = timingOutQueryEvaluatorFactory(config) val start = Time.now queryEvaluator2.select("SELECT GET_LOCK('padlock', 60) AS rv") { row => row.getInt("rv") } must throwA[SqlQueryTimeoutException] val end = Time.now diff --git a/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala b/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala index 119dd5a..4430596 100644 --- a/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala @@ -6,8 +6,8 @@ import java.sql.{ResultSet, SQLException, SQLIntegrityConstraintViolationExcepti import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException import com.twitter.querulous.test.FakeQueryEvaluator import com.twitter.querulous.evaluator.{AutoDisablingQueryEvaluator, Transaction} -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ class AutoDisablingQueryEvaluatorSpec extends Specification with JMocker with ClassMocker { @@ -18,8 +18,6 @@ class AutoDisablingQueryEvaluatorSpec extends Specification with JMocker with Cl val disableDuration = 1.minute val queryEvaluator = new FakeQueryEvaluator(trans, List(mock[ResultSet])) val autoDisablingQueryEvaluator = new AutoDisablingQueryEvaluator(queryEvaluator, disableErrorCount, disableDuration) - Time.freeze() - "when there are no failures" >> { autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { _ => 1 } mustEqual List(1) } @@ -60,21 +58,23 @@ class AutoDisablingQueryEvaluatorSpec extends Specification with JMocker with Cl } "when there are more than disableErrorCount failures but disableDuration has elapsed" >> { - var invocationCount = 0 + Time.withCurrentTimeFrozen { time => + var invocationCount = 0 - (0 until disableErrorCount + 1) foreach { i => + (0 until disableErrorCount + 1) foreach { i => + autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { resultSet => + invocationCount += 1 + throw new SQLException + } must throwA[SQLException] + } + invocationCount mustEqual disableErrorCount + + time.advance(1.minute) autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { resultSet => invocationCount += 1 - throw new SQLException - } must throwA[SQLException] - } - invocationCount mustEqual disableErrorCount - - Time.advance(1.minute) - autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { resultSet => - invocationCount += 1 + } + invocationCount mustEqual disableErrorCount + 1 } - invocationCount mustEqual disableErrorCount + 1 } } } diff --git a/src/test/scala/com/twitter/querulous/unit/DatabaseSpec.scala b/src/test/scala/com/twitter/querulous/unit/DatabaseSpec.scala new file mode 100644 index 0000000..8293d69 --- /dev/null +++ b/src/test/scala/com/twitter/querulous/unit/DatabaseSpec.scala @@ -0,0 +1,71 @@ +package com.twitter.querulous.unit + +import java.sql.{PreparedStatement, Connection, Types} +import org.apache.commons.dbcp.{DelegatingConnection => DBCPConnection} +import com.mysql.jdbc.{ConnectionImpl => MySQLConnection} +import java.util.Properties +import org.specs.mock.{ClassMocker, JMocker} +import com.twitter.util.TimeConversions._ +import com.twitter.querulous.database._ +import com.twitter.querulous.ConfiguredSpecification + + +class DatabaseSpec extends ConfiguredSpecification with JMocker with ClassMocker { + val defaultProps = Map("socketTimeout" -> "41", "connectTimeout" -> "42") + + def mysqlConn(conn: Connection) = conn match { + case c: DBCPConnection => + c.getInnermostDelegate.asInstanceOf[MySQLConnection] + case c: MySQLConnection => c + } + + def testFactory(factory: DatabaseFactory) { + "allow specification of default query options" in { + val db = factory(config.hostnames.toList, null, config.username, config.password) + val props = mysqlConn(db.open).getProperties + + props.getProperty("connectTimeout") mustEqual "42" + props.getProperty("socketTimeout") mustEqual "41" + } + + "allow override of default query options" in { + val db = factory( + config.hostnames.toList, + null, + config.username, + config.password, + Map("connectTimeout" -> "43")) + val props = mysqlConn(db.open).getProperties + + props.getProperty("connectTimeout") mustEqual "43" + props.getProperty("socketTimeout") mustEqual "41" + } + } + + "SingleConnectionDatabaseFactory" should { + val factory = new SingleConnectionDatabaseFactory(defaultProps) + testFactory(factory) + } + + "ApachePoolingDatabaseFactory" should { + val factory = new ApachePoolingDatabaseFactory( + 10, 10, 1.second, 10.millis, false, 0.seconds, defaultProps + ) + + testFactory(factory) + } + + "Database#url" should { + val fake = new Object with Database { + def open() = null + def close(connection: Connection) = () + }.asInstanceOf[{def url(a: List[String], b:String, c:Map[String, String]): String}] + + "add default unicode urlOptions" in { + val url = fake.url(List("host"), "db", Map()) + + url mustMatch "useUnicode=true" + url mustMatch "characterEncoding=UTF-8" + } + } +} diff --git a/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala b/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala index f5a61aa..701ba80 100644 --- a/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala @@ -18,8 +18,8 @@ class MemoizingDatabaseFactorySpec extends Specification with JMocker { val memoizingDatabase = new MemoizingDatabaseFactory(databaseFactory) expect { - one(databaseFactory).apply(hosts, "bar", username, password, null) willReturn database1 - one(databaseFactory).apply(hosts, "baz", username, password, null) willReturn database2 + one(databaseFactory).apply(hosts, "bar", username, password, Map.empty) willReturn database1 + one(databaseFactory).apply(hosts, "baz", username, password, Map.empty) willReturn database2 } memoizingDatabase(hosts, "bar", username, password) mustBe database1 memoizingDatabase(hosts, "bar", username, password) mustBe database1 diff --git a/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala b/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala index 9497d53..6099a77 100644 --- a/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala @@ -2,30 +2,24 @@ package com.twitter.querulous.unit import java.sql.{SQLException, DriverManager, Connection} import scala.collection.mutable -import net.lag.configgy.{Config, Configgy} import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException +import com.twitter.querulous.{StatsCollector, TestEvaluator} import com.twitter.querulous.database.{ApachePoolingDatabaseFactory, MemoizingDatabaseFactory, Database} import com.twitter.querulous.evaluator.{StandardQueryEvaluator, StandardQueryEvaluatorFactory, QueryEvaluator} import com.twitter.querulous.query._ import com.twitter.querulous.test.FakeDatabase -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ -import org.specs.Specification +import com.twitter.querulous.ConfiguredSpecification +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ import org.specs.mock.{ClassMocker, JMocker} -class QueryEvaluatorSpec extends Specification with JMocker with ClassMocker { - Configgy.configure("config/test.conf") +class QueryEvaluatorSpec extends ConfiguredSpecification with JMocker with ClassMocker { import TestEvaluator._ - val config = Configgy.config.configMap("db") - val username = config("username") - val password = config("password") - val urlOptions = config.configMap("url_options").asMap.asInstanceOf[Map[String, String]] - "QueryEvaluator" should { - val queryEvaluator = testEvaluatorFactory("localhost", "db_test", username, password, urlOptions) - val rootQueryEvaluator = testEvaluatorFactory("localhost", null, username, password, urlOptions) + val queryEvaluator = testEvaluatorFactory(config) + val rootQueryEvaluator = testEvaluatorFactory(config.withoutDatabase) val queryFactory = new SqlQueryFactory doBefore { @@ -36,27 +30,9 @@ class QueryEvaluatorSpec extends Specification with JMocker with ClassMocker { queryEvaluator.execute("DROP TABLE IF EXISTS foo") } - "fromConfig" in { - val stats = mock[StatsCollector] - QueryFactory.fromConfig(Config.fromMap(Map.empty), None) must haveClass[SqlQueryFactory] - QueryFactory.fromConfig(Config.fromMap(Map.empty), Some(stats)) must - haveClass[StatsCollectingQueryFactory] - QueryFactory.fromConfig(Config.fromMap(Map("query_timeout_default" -> "10")), None) must - haveClass[TimingOutQueryFactory] - QueryFactory.fromConfig(Config.fromMap(Map("retries" -> "10")), None) must - haveClass[RetryingQueryFactory] - QueryFactory.fromConfig(Config.fromMap(Map("debug" -> "true")), None) must - haveClass[DebuggingQueryFactory] - - val config = new Config() - config.setConfigMap("queries", new Config()) - config("query_timeout_default") = "10" - QueryFactory.fromConfig(config, Some(stats)) must haveClass[TimingOutStatsCollectingQueryFactory] - } - "connection pooling" in { val connection = mock[Connection] - val database = new FakeDatabase(connection, 1.millis) + val database = new FakeDatabase(connection) "transactionally" >> { val queryEvaluator = new StandardQueryEvaluator(database, queryFactory) @@ -95,7 +71,8 @@ class QueryEvaluatorSpec extends Specification with JMocker with ClassMocker { "fallback to a read slave" in { // should always succeed if you have the right mysql driver. - val queryEvaluator = testEvaluatorFactory(List("localhost:12349", "localhost"), "db_test", username, password) + val queryEvaluator = testEvaluatorFactory( + "localhost:12349" :: config.hostnames.toList, config.database, config.username, config.password) queryEvaluator.selectOne("SELECT 1") { row => row.getInt(1) }.toList mustEqual List(1) queryEvaluator.execute("CREATE TABLE foo (id INT)") must throwA[SQLException] } diff --git a/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala index ae9e5bc..0515022 100644 --- a/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala @@ -1,12 +1,11 @@ package com.twitter.querulous.unit -import java.sql.{PreparedStatement, Connection, Types} +import java.sql.{PreparedStatement, Connection, Types, SQLException} import org.specs.Specification import org.specs.mock.{ClassMocker, JMocker} import com.twitter.querulous.query.NullValues._ import com.twitter.querulous.query.{NullValues, SqlQuery} - class SqlQuerySpec extends Specification with JMocker with ClassMocker { "SqlQuery" should { "typecast" in { @@ -23,6 +22,50 @@ class SqlQuerySpec extends Specification with JMocker with ClassMocker { } new SqlQuery(connection, "SELECT * FROM foo WHERE id IN (?)", List(1, 2, 3)).select { _ => 1 } } + + "arrays of pairs" in { + val connection = mock[Connection] + val statement = mock[PreparedStatement] + expect { + one(connection).prepareStatement("SELECT * FROM foo WHERE (id, uid) IN ((?,?),(?,?))") willReturn statement + one(statement).setInt(1, 1) then + one(statement).setInt(2, 2) then + one(statement).setInt(3, 3) then + one(statement).setInt(4, 4) then + one(statement).executeQuery() then + one(statement).getResultSet + } + new SqlQuery(connection, "SELECT * FROM foo WHERE (id, uid) IN (?)", List((1, 2), (3, 4))).select { _ => 1 } + } + + "arrays of tuple3s" in { + val connection = mock[Connection] + val statement = mock[PreparedStatement] + expect { + one(connection).prepareStatement("SELECT * FROM foo WHERE (id1, id2, id3) IN ((?,?,?))") willReturn statement + one(statement).setInt(1, 1) then + one(statement).setInt(2, 2) then + one(statement).setInt(3, 3) then + one(statement).executeQuery() then + one(statement).getResultSet + } + new SqlQuery(connection, "SELECT * FROM foo WHERE (id1, id2, id3) IN (?)", List((1, 2, 3))).select { _ => 1 } + } + + "arrays of tuple4s" in { + val connection = mock[Connection] + val statement = mock[PreparedStatement] + expect { + one(connection).prepareStatement("SELECT * FROM foo WHERE (id1, id2, id3, id4) IN ((?,?,?,?))") willReturn statement + one(statement).setInt(1, 1) then + one(statement).setInt(2, 2) then + one(statement).setInt(3, 3) then + one(statement).setInt(4, 4) then + one(statement).executeQuery() then + one(statement).getResultSet + } + new SqlQuery(connection, "SELECT * FROM foo WHERE (id1, id2, id3, id4) IN (?)", List((1, 2, 3, 4))).select { _ => 1 } + } } "create a query string" in { @@ -57,5 +100,34 @@ class SqlQuerySpec extends Specification with JMocker with ClassMocker { new SqlQuery(connection, queryString, NullString, NullInt, NullDouble, NullBoolean, NullLong, NullValues(Types.VARBINARY)) } + + "handle exceptions" in { + val queryString = "INSERT INTO TABLE (col1) VALUES (?)" + val connection = mock[Connection] + val statement = mock[PreparedStatement] + val unrecognizedType = connection + "throw illegal argument exception if type passed in is unrecognized" in { + expect { + one(connection).prepareStatement(queryString) willReturn statement + } + new SqlQuery(connection, queryString, unrecognizedType) must throwAn[IllegalArgumentException] + } + "throw chained-exception" in { + val expectedCauseException = new SQLException("") + expect { + one(connection).prepareStatement(queryString) willReturn statement then + one(statement).setString(1, "one") willThrow expectedCauseException + } + try { + new SqlQuery(connection, queryString, "one") + fail("should throw") + } catch { + case e: Exception => { + e.getCause must beEqualTo(expectedCauseException) + } + case _ => fail("unknown throwable") + } + } + } } } diff --git a/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala b/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala index ba6d1fe..a4ccab2 100644 --- a/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala @@ -4,29 +4,47 @@ import scala.collection.mutable.Map import java.sql.Connection import org.specs.Specification import org.specs.mock.{ClassMocker, JMocker} -import com.twitter.querulous.database.StatsCollectingDatabase +import com.twitter.querulous.database.{SqlDatabaseTimeoutException, StatsCollectingDatabase} import com.twitter.querulous.test.{FakeStatsCollector, FakeDatabase} -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ class StatsCollectingDatabaseSpec extends Specification with JMocker with ClassMocker { "StatsCollectingDatabase" should { - Time.freeze() val latency = 1.second val connection = mock[Connection] val stats = new FakeStatsCollector - val pool = new StatsCollectingDatabase(new FakeDatabase(connection, latency), stats) + def pool(callback: String => Unit) = new StatsCollectingDatabase(new FakeDatabase(connection, callback), stats) "collect stats" in { "when closing" >> { - pool.close(connection) - stats.times("database-close-timing") mustEqual latency.inMillis + Time.withCurrentTimeFrozen { time => + pool(s => time.advance(latency)).close(connection) + stats.times("db-close-timing") mustEqual latency.inMillis + } } "when opening" >> { - pool.open() - stats.times("database-open-timing") mustEqual latency.inMillis + Time.withCurrentTimeFrozen { time => + pool(s => time.advance(latency)).open() + stats.times("db-open-timing") mustEqual latency.inMillis + } + } + } + + "collect timeout stats" in { + val e = new SqlDatabaseTimeoutException("foo", 0.seconds) + "when closing" >> { + pool(s => throw e).close(connection) must throwA[SqlDatabaseTimeoutException] + stats.counts("db-close-timeout-count") mustEqual 1 + } + + "when opening" >> { + Time.withCurrentTimeFrozen { time => + pool(s => throw e).open() must throwA[SqlDatabaseTimeoutException] + stats.counts("db-open-timeout-count") mustEqual 1 + } } } } diff --git a/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala index f062da8..c8d4201 100644 --- a/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala @@ -3,31 +3,45 @@ package com.twitter.querulous.unit import java.sql.ResultSet import org.specs.Specification import org.specs.mock.JMocker -import com.twitter.querulous.query.StatsCollectingQuery +import com.twitter.querulous.query.{QueryClass, SqlQueryTimeoutException, StatsCollectingQuery} import com.twitter.querulous.test.{FakeQuery, FakeStatsCollector} -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ class StatsCollectingQuerySpec extends Specification with JMocker { "StatsCollectingQuery" should { - Time.freeze() - val latency = 1.second - val testQuery = new FakeQuery(List(mock[ResultSet])) { - override def select[A](f: ResultSet => A) = { - Time.advance(latency) - super.select(f) + "collect stats" in { + Time.withCurrentTimeFrozen { time => + val latency = 1.second + val stats = new FakeStatsCollector + val testQuery = new FakeQuery(List(mock[ResultSet])) { + override def select[A](f: ResultSet => A) = { + time.advance(latency) + super.select(f) + } + } + val statsCollectingQuery = new StatsCollectingQuery(testQuery, QueryClass.Select, stats) + + statsCollectingQuery.select { _ => 1 } mustEqual List(1) + + stats.counts("db-select-count") mustEqual 1 + stats.times("db-timing") mustEqual latency.inMillis } } - "collect stats" in { - val stats = new FakeStatsCollector - val statsCollectingQuery = new StatsCollectingQuery(testQuery, stats) + "collect timeout stats" in { + Time.withCurrentTimeFrozen { time => + val stats = new FakeStatsCollector + val testQuery = new FakeQuery(List(mock[ResultSet])) + val statsCollectingQuery = new StatsCollectingQuery(testQuery, QueryClass.Select, stats) + val e = new SqlQueryTimeoutException(0.seconds) - statsCollectingQuery.select { _ => 1 } mustEqual List(1) + statsCollectingQuery.select { _ => throw e } must throwA[SqlQueryTimeoutException] - stats.counts("db-select-count") mustEqual 1 - stats.times("db-timing") mustEqual latency.inMillis + stats.counts("db-query-timeout-count") mustEqual 1 + stats.counts("db-query-" + QueryClass.Select.name + "-timeout-count") mustEqual 1 + } } } } diff --git a/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala b/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala index 0785194..b1e8f3a 100644 --- a/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala @@ -1,22 +1,23 @@ package com.twitter.querulous.unit +import com.twitter.querulous._ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.sql.Connection import com.twitter.querulous.TimeoutException import com.twitter.querulous.database.{SqlDatabaseTimeoutException, TimingOutDatabase, Database} -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.util.Time +import com.twitter.util.TimeConversions._ import org.specs.Specification import org.specs.mock.{JMocker, ClassMocker} class TimingOutDatabaseSpec extends Specification with JMocker with ClassMocker { "TimingOutDatabaseSpec" should { - Time.reset() val latch = new CountDownLatch(1) val timeout = 1.second var shouldWait = false val connection = mock[Connection] + val future = new FutureTimeout(1, 1) val database = new Database { def open() = { if (shouldWait) latch.await(100.seconds.inMillis, TimeUnit.MILLISECONDS) @@ -29,7 +30,7 @@ class TimingOutDatabaseSpec extends Specification with JMocker with ClassMocker // one(connection).close() } - val timingOutDatabase = new TimingOutDatabase(database, List("dbhost"), "dbname", 1, 1, timeout, timeout, 1) + val timingOutDatabase = new TimingOutDatabase(database, List("dbhost"), "dbname", future, timeout, 1) shouldWait = true "timeout" in { diff --git a/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala index 1b93de8..6bc28cd 100644 --- a/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala +++ b/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala @@ -1,22 +1,22 @@ package com.twitter.querulous.unit import java.sql.ResultSet -import net.lag.configgy.Configgy import org.specs.Specification import org.specs.mock.{JMocker, ClassMocker} +import com.twitter.querulous.TestEvaluator import com.twitter.querulous.test.FakeQuery import com.twitter.querulous.query.{TimingOutQuery, SqlQueryTimeoutException} -import com.twitter.xrayspecs.Duration -import com.twitter.xrayspecs.TimeConversions._ +import com.twitter.querulous.ConfiguredSpecification +import com.twitter.util.Duration +import com.twitter.util.TimeConversions._ import java.util.concurrent.{CountDownLatch, TimeUnit} -class TimingOutQuerySpec extends Specification with JMocker with ClassMocker { +object TimingOutQuerySpec extends ConfiguredSpecification with JMocker with ClassMocker { "TimingOutQuery" should { - val config = Configgy.config.configMap("db") - val connection = TestEvaluator.testDatabaseFactory(List("localhost"), config("username"), config("password")).open() + val connection = TestEvaluator.testDatabaseFactory( + config.hostnames.toList, config.username, config.password).open() val timeout = 1.second - val cancelTimeout = 0.millis val resultSet = mock[ResultSet] "timeout" in { @@ -29,7 +29,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker { super.select(f) } } - val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout) + val timingOutQuery = new TimingOutQuery(query, connection, timeout, true) timingOutQuery.select { r => 1 } must throwA[SqlQueryTimeoutException] latch.getCount mustEqual 0 @@ -40,7 +40,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker { val query = new FakeQuery(List(resultSet)) { override def cancel() = { latch.countDown() } } - val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout) + val timingOutQuery = new TimingOutQuery(query, connection, timeout, true) timingOutQuery.select { r => 1 } latch.getCount mustEqual 1 diff --git a/src/test/scala/com/twitter/querulous/unit/TimingOutStatsCollectingQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/TimingOutStatsCollectingQuerySpec.scala deleted file mode 100644 index 8bef8d7..0000000 --- a/src/test/scala/com/twitter/querulous/unit/TimingOutStatsCollectingQuerySpec.scala +++ /dev/null @@ -1,54 +0,0 @@ -package com.twitter.querulous.unit - -import java.sql.ResultSet -import org.specs.Specification -import org.specs.mock.{ClassMocker, JMocker} -import com.twitter.querulous.test.{FakeQuery, FakeStatsCollector} -import com.twitter.querulous.query.TimingOutStatsCollectingQuery -import com.twitter.xrayspecs.Time -import com.twitter.xrayspecs.TimeConversions._ - - -class TimingOutStatsCollectingQuerySpec extends Specification with JMocker { - "TimingOutStatsCollectingQuery" should { - Time.freeze() - val latency = 1.second - val testQuery = new FakeQuery(List(mock[ResultSet])) { - override def select[A](f: ResultSet => A) = { - Time.advance(latency) - super.select(f) - } - override def execute() = { - Time.advance(latency) - super.execute() - } - } - - "collect stats" in { - val stats = new FakeStatsCollector - - "selects" >> { - val query = new TimingOutStatsCollectingQuery(testQuery, "selectTest", stats) - query.select { _ => 1 } mustEqual List(1) - stats.times("db-select-timing") mustEqual latency.inMillis - stats.times("x-db-query-timing-selectTest") mustEqual latency.inMillis - stats.counts("db-select-count") mustEqual 1 - } - - "executes" >> { - val query = new TimingOutStatsCollectingQuery(testQuery, "executeTest", stats) - query.execute - stats.times("db-execute-timing") mustEqual latency.inMillis - stats.times("x-db-query-timing-executeTest") mustEqual latency.inMillis - stats.counts("db-execute-count") mustEqual 1 - } - - "globally" >> { - val query = new TimingOutStatsCollectingQuery(testQuery, "default", stats) - query.select { _ => 1 } mustEqual List(1) - stats.times("db-timing") mustEqual latency.inMillis - } - } - } -} -