diff --git a/.gitignore b/.gitignore
index 14c6bac..dafc1ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,10 @@
+.idea/
+*.iml
+*.ipr
dist/
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/
+querulous.tmproj
diff --git a/README.markdown b/README.markdown
index 91e09a1..cde889c 100644
--- a/README.markdown
+++ b/README.markdown
@@ -154,7 +154,7 @@ Add the following dependency and repository stanzas to your project's configurat
twitter.com
- http://www.lag.net/nest
+ http://maven.twttr.com/
### Ivy
@@ -165,7 +165,7 @@ Add the following dependency to ivy.xml
and the following repository to ivysettings.xml
-
+
## Running Tests
diff --git a/config/test.conf b/config/test.conf
deleted file mode 100644
index 6483d9a..0000000
--- a/config/test.conf
+++ /dev/null
@@ -1,9 +0,0 @@
-db {
- hostname = "localhost"
- username = "root"
- password = ""
- url_options {
- useUnicode = "true"
- characterEncoding = "UTF-8"
- }
-}
diff --git a/config/test.scala b/config/test.scala
new file mode 100644
index 0000000..40b5c34
--- /dev/null
+++ b/config/test.scala
@@ -0,0 +1,15 @@
+import com.twitter.querulous.config.Connection
+
+new Connection {
+ val hostnames = Seq("localhost")
+ val database = "db_test"
+ val username = {
+ val userEnv = System.getenv("DB_USERNAME")
+ if (userEnv == null) "root" else userEnv
+ }
+
+ val password = {
+ val passEnv = System.getenv("DB_PASSWORD")
+ if (passEnv == null) "" else passEnv
+ }
+}
diff --git a/project/build.properties b/project/build.properties
index 2da0f90..7e84939 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -1,8 +1,8 @@
#Project properties
-#Fri Sep 17 11:09:45 PDT 2010
+#Mon Feb 28 13:01:14 PST 2011
project.organization=com.twitter
project.name=querulous
sbt.version=0.7.4
-project.version=1.2.2
-build.scala.versions=2.7.7
+project.version=2.0.2-SNAPSHOT
+build.scala.versions=2.8.1
project.initialize=false
diff --git a/project/build/QuerulousProject.scala b/project/build/QuerulousProject.scala
index 3764ce3..ee9eaee 100644
--- a/project/build/QuerulousProject.scala
+++ b/project/build/QuerulousProject.scala
@@ -1,30 +1,23 @@
import sbt._
+import Process._
import com.twitter.sbt._
-
class QuerulousProject(info: ProjectInfo) extends StandardProject(info) with SubversionPublisher {
- val vscaladoc = "org.scala-tools" % "vscaladoc" % "1.1-md-3"
- val configgy = "net.lag" % "configgy" % "1.5.2"
+ override def filterScalaJars = false
+
+ val util = "com.twitter" % "util" % "1.6.4"
+
val dbcp = "commons-dbcp" % "commons-dbcp" % "1.4"
- val mysqljdbc = "mysql" % "mysql-connector-java" % "5.1.6"
+ val mysqljdbc = "mysql" % "mysql-connector-java" % "5.1.13"
val pool = "commons-pool" % "commons-pool" % "1.5.4"
- val xrayspecs = "com.twitter" % "xrayspecs" % "1.0.7"
-
- val hamcrest = "org.hamcrest" % "hamcrest-all" % "1.1" % "test"
- val specs = "org.scala-tools.testing" % "specs" % "1.6.2.1" % "test"
- val objenesis = "org.objenesis" % "objenesis" % "1.1" % "test"
- val jmock = "org.jmock" % "jmock" % "2.4.0" % "test"
- val cglib = "cglib" % "cglib" % "2.1_3" % "test"
- val asm = "asm" % "asm" % "1.5.3" % "test"
- override def pomExtra =
-
-
- Apache 2
- http://www.apache.org/licenses/LICENSE-2.0.txt
- repo
-
-
+ val scalaTools = "org.scala-lang" % "scala-compiler" % "2.8.1" % "test"
+ val hamcrest = "org.hamcrest" % "hamcrest-all" % "1.1" % "test"
+ val specs = "org.scala-tools.testing" % "specs_2.8.0" % "1.6.5" % "test"
+ val objenesis = "org.objenesis" % "objenesis" % "1.1" % "test"
+ val jmock = "org.jmock" % "jmock" % "2.4.0" % "test"
+ val cglib = "cglib" % "cglib" % "2.1_3" % "test"
+ val asm = "asm" % "asm" % "1.5.3" % "test"
override def subversionRepository = Some("http://svn.local.twitter.com/maven-public/")
}
diff --git a/project/plugins/Plugins.scala b/project/plugins/Plugins.scala
index 9bb719d..7e3ef44 100644
--- a/project/plugins/Plugins.scala
+++ b/project/plugins/Plugins.scala
@@ -1,6 +1,6 @@
import sbt._
class Plugins(info: ProjectInfo) extends PluginDefinition(info) {
- val twitterNest = "com.twitter" at "http://www.lag.net/nest"
- val defaultProject = "com.twitter" % "standard-project" % "0.7.1"
+ val twitter = "twitter.com" at "http://maven.twttr.com/"
+ val defaultProject = "com.twitter" % "standard-project" % "0.7.17"
}
diff --git a/src/main/scala/com/twitter/querulous/AutoDisabler.scala b/src/main/scala/com/twitter/querulous/AutoDisabler.scala
index aefff81..f0d23f7 100644
--- a/src/main/scala/com/twitter/querulous/AutoDisabler.scala
+++ b/src/main/scala/com/twitter/querulous/AutoDisabler.scala
@@ -1,7 +1,7 @@
package com.twitter.querulous
-import com.twitter.xrayspecs.{Time, Duration}
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.{Time, Duration}
+import com.twitter.util.TimeConversions._
import java.sql.{SQLException, SQLIntegrityConstraintViolationException}
@@ -9,7 +9,7 @@ trait AutoDisabler {
protected val disableErrorCount: Int
protected val disableDuration: Duration
- private var disabledUntil: Time = Time.never
+ private var disabledUntil: Time = Time.epoch
private var consecutiveErrors = 0
protected def throwIfDisabled(throwMessage: String): Unit = {
diff --git a/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala b/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala
index 907507e..09f010b 100644
--- a/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala
+++ b/src/main/scala/com/twitter/querulous/ConnectionDestroying.scala
@@ -4,12 +4,10 @@ import java.sql.Connection
import org.apache.commons.dbcp.{DelegatingConnection => DBCPConnection}
import com.mysql.jdbc.{ConnectionImpl => MySQLConnection}
-
// Emergency connection destruction toolkit
-
trait ConnectionDestroying {
def destroyConnection(conn: Connection) {
- if ( !conn.isClosed )
+ if (!conn.isClosed)
conn match {
case c: DBCPConnection =>
destroyDbcpWrappedConnection(c)
@@ -22,21 +20,18 @@ trait ConnectionDestroying {
def destroyDbcpWrappedConnection(conn: DBCPConnection) {
val inner = conn.getInnermostDelegate
- if ( inner != null ) {
+ if (inner ne null) {
destroyConnection(inner)
} else {
- // this should never happen if we use our own ApachePoolingDatabase to get connections.
- error("Could not get access to the delegate connection. Make sure the dbcp connection pool allows access to underlying connections.")
+ // might just be a race; move on.
+ return
}
// "close" the wrapper so that it updates its internal bookkeeping, just do it
- try { conn.close } catch { case _ => }
+ try { conn.close() } catch { case _ => }
}
def destroyMysqlConnection(conn: MySQLConnection) {
- val abort = Class.forName("com.mysql.jdbc.ConnectionImpl").getDeclaredMethod("abortInternal")
- abort.setAccessible(true)
-
- abort.invoke(conn)
+ conn.abortInternal()
}
}
diff --git a/src/main/scala/com/twitter/querulous/FutureTimeout.scala b/src/main/scala/com/twitter/querulous/FutureTimeout.scala
index 1975742..931f1de 100644
--- a/src/main/scala/com/twitter/querulous/FutureTimeout.scala
+++ b/src/main/scala/com/twitter/querulous/FutureTimeout.scala
@@ -2,7 +2,7 @@ package com.twitter.querulous
import java.util.concurrent.{ThreadFactory, TimeoutException => JTimeoutException, _}
import java.util.concurrent.atomic.AtomicInteger
-import com.twitter.xrayspecs.Duration
+import com.twitter.util.Duration
class FutureTimeout(poolSize: Int, queueSize: Int) {
object DaemonThreadFactory extends ThreadFactory {
@@ -20,7 +20,7 @@ class FutureTimeout(poolSize: Int, queueSize: Int) {
thread
}
}
- private val executor = new ThreadPoolExecutor(poolSize, poolSize, 0, TimeUnit.SECONDS,
+ private val executor = new ThreadPoolExecutor(1, poolSize, 60, TimeUnit.SECONDS,
new LinkedBlockingQueue[Runnable](queueSize),
DaemonThreadFactory)
diff --git a/src/main/scala/com/twitter/querulous/StatsCollector.scala b/src/main/scala/com/twitter/querulous/StatsCollector.scala
index 7725972..41ccd0a 100644
--- a/src/main/scala/com/twitter/querulous/StatsCollector.scala
+++ b/src/main/scala/com/twitter/querulous/StatsCollector.scala
@@ -4,3 +4,8 @@ trait StatsCollector {
def incr(name: String, count: Int)
def time[A](name: String)(f: => A): A
}
+
+object NullStatsCollector extends StatsCollector {
+ def incr(name: String, count: Int) {}
+ def time[A](name: String)(f: => A): A = f
+}
diff --git a/src/main/scala/com/twitter/querulous/Timeout.scala b/src/main/scala/com/twitter/querulous/Timeout.scala
index 8167d82..365a583 100644
--- a/src/main/scala/com/twitter/querulous/Timeout.scala
+++ b/src/main/scala/com/twitter/querulous/Timeout.scala
@@ -1,7 +1,7 @@
package com.twitter.querulous
import java.util.{Timer, TimerTask}
-import com.twitter.xrayspecs.Duration
+import com.twitter.util.Duration
class TimeoutException extends Exception
@@ -20,7 +20,9 @@ object Timeout {
} finally {
task map { t =>
t.cancel()
- timer.purge()
+ // TODO(benjy): Timer is not optimized to deal with large numbers of cancellations: it releases and reacquires its monitor
+ // on every task, cancelled or not, when it could quickly skip over all cancelled tasks in a single monitor region.
+ // This may not be a problem, but it's something to be aware of.
}
if (cancelled) throw new TimeoutException
}
@@ -32,7 +34,15 @@ object Timeout {
private def schedule(timer: Timer, timeout: Duration, f: => Unit) = {
val task = new TimerTask() {
- override def run() { f }
+ override def run() {
+ try {
+ f
+ } catch {
+ case e: Throwable =>
+ error("Timer task tried to throw an exception: " + e.toString())
+ e.printStackTrace(System.err)
+ }
+ }
}
timer.schedule(task, timeout.inMillis)
task
diff --git a/src/main/scala/com/twitter/querulous/config/Database.scala b/src/main/scala/com/twitter/querulous/config/Database.scala
new file mode 100644
index 0000000..cd8c650
--- /dev/null
+++ b/src/main/scala/com/twitter/querulous/config/Database.scala
@@ -0,0 +1,125 @@
+package com.twitter.querulous.config
+
+import com.twitter.querulous._
+import com.twitter.util.Duration
+import com.twitter.util.TimeConversions._
+import database._
+
+
+class ApachePoolingDatabase {
+ var sizeMin: Int = 10
+ var sizeMax: Int = 10
+ var testIdle: Duration = 1.second
+ var maxWait: Duration = 10.millis
+ var minEvictableIdle: Duration = 60.seconds
+ var testOnBorrow: Boolean = false
+}
+
+class TimingOutDatabase {
+ var poolSize: Int = 10
+ var queueSize: Int = 10000
+ var open: Duration = 1.second
+}
+
+trait AutoDisablingDatabase {
+ def errorCount: Int
+ def interval: Duration
+}
+
+class Database {
+ var pool: Option[ApachePoolingDatabase] = None
+ def pool_=(p: ApachePoolingDatabase) { pool = Some(p) }
+ var autoDisable: Option[AutoDisablingDatabase] = None
+ def autoDisable_=(a: AutoDisablingDatabase) { autoDisable = Some(a) }
+ var timeout: Option[TimingOutDatabase] = None
+ def timeout_=(t: TimingOutDatabase) { timeout = Some(t) }
+ var memoize: Boolean = true
+
+ def apply(stats: StatsCollector): DatabaseFactory = {
+ var factory: DatabaseFactory = pool.map(apacheConfig =>
+ new ApachePoolingDatabaseFactory(
+ apacheConfig.sizeMin,
+ apacheConfig.sizeMax,
+ apacheConfig.testIdle,
+ apacheConfig.maxWait,
+ apacheConfig.testOnBorrow,
+ apacheConfig.minEvictableIdle)
+ ).getOrElse(new SingleConnectionDatabaseFactory)
+
+ timeout.foreach { timeoutConfig =>
+ factory = new TimingOutDatabaseFactory(factory,
+ timeoutConfig.poolSize,
+ timeoutConfig.queueSize,
+ timeoutConfig.open,
+ timeoutConfig.poolSize)
+ }
+
+ if (stats ne NullStatsCollector) {
+ factory = new StatsCollectingDatabaseFactory(factory, stats)
+ }
+
+ autoDisable.foreach { disable =>
+ factory = new AutoDisablingDatabaseFactory(factory, disable.errorCount, disable.interval)
+ }
+
+ if (memoize) {
+ factory = new MemoizingDatabaseFactory(factory)
+ }
+
+ factory
+ }
+
+ def apply(): DatabaseFactory = apply(NullStatsCollector)
+}
+
+trait Connection {
+ def hostnames: Seq[String]
+ def database: String
+ def username: String
+ def password: String
+ var urlOptions: Map[String, String] = Map()
+
+ def withHost(newHost: String) = {
+ val current = this
+ new Connection {
+ def hostnames = Seq(newHost)
+ def database = current.database
+ def username = current.username
+ def password = current.password
+ urlOptions = current.urlOptions
+ }
+ }
+
+ def withHosts(newHosts: Seq[String]) = {
+ val current = this
+ new Connection {
+ def hostnames = newHosts
+ def database = current.database
+ def username = current.username
+ def password = current.password
+ urlOptions = current.urlOptions
+ }
+ }
+
+ def withDatabase(newDatabase: String) = {
+ val current = this
+ new Connection {
+ def hostnames = current.hostnames
+ def database = newDatabase
+ def username = current.username
+ def password = current.password
+ urlOptions = current.urlOptions
+ }
+ }
+
+ def withoutDatabase = {
+ val current = this
+ new Connection {
+ def hostnames = current.hostnames
+ def database = null
+ def username = current.username
+ def password = current.password
+ urlOptions = current.urlOptions
+ }
+ }
+}
diff --git a/src/main/scala/com/twitter/querulous/config/Query.scala b/src/main/scala/com/twitter/querulous/config/Query.scala
new file mode 100644
index 0000000..2467a17
--- /dev/null
+++ b/src/main/scala/com/twitter/querulous/config/Query.scala
@@ -0,0 +1,59 @@
+package com.twitter.querulous.config
+
+import com.twitter.querulous._
+import com.twitter.util.Duration
+import com.twitter.util.TimeConversions._
+import query._
+
+
+object QueryTimeout {
+ def apply(timeout: Duration, cancelOnTimeout: Boolean) =
+ new QueryTimeout(timeout, cancelOnTimeout)
+
+ def apply(timeout: Duration) =
+ new QueryTimeout(timeout, false)
+}
+
+class QueryTimeout(val timeout: Duration, val cancelOnTimeout: Boolean)
+
+object NoDebugOutput extends (String => Unit) {
+ def apply(s: String) = ()
+}
+
+class Query {
+ var timeouts: Map[QueryClass, QueryTimeout] = Map(
+ QueryClass.Select -> QueryTimeout(5.seconds),
+ QueryClass.Execute -> QueryTimeout(5.seconds)
+ )
+
+ var retries: Int = 0
+ var debug: (String => Unit) = NoDebugOutput
+
+ def apply(statsCollector: StatsCollector): QueryFactory = {
+ var queryFactory: QueryFactory = new SqlQueryFactory
+
+ if (!timeouts.isEmpty) {
+ val tupleTimeout = Map(timeouts.map { case (queryClass, timeout) =>
+ (queryClass, (timeout.timeout, timeout.cancelOnTimeout))
+ }.toList: _*)
+
+ queryFactory = new PerQueryTimingOutQueryFactory(new SqlQueryFactory, tupleTimeout)
+ }
+
+ if (statsCollector ne NullStatsCollector) {
+ queryFactory = new StatsCollectingQueryFactory(queryFactory, statsCollector)
+ }
+
+ if (retries > 0) {
+ queryFactory = new RetryingQueryFactory(queryFactory, retries)
+ }
+
+ if (debug ne NoDebugOutput) {
+ queryFactory = new DebuggingQueryFactory(queryFactory, debug)
+ }
+
+ queryFactory
+ }
+
+ def apply(): QueryFactory = apply(NullStatsCollector)
+}
diff --git a/src/main/scala/com/twitter/querulous/config/QueryEvaluator.scala b/src/main/scala/com/twitter/querulous/config/QueryEvaluator.scala
new file mode 100644
index 0000000..9fc09ff
--- /dev/null
+++ b/src/main/scala/com/twitter/querulous/config/QueryEvaluator.scala
@@ -0,0 +1,33 @@
+package com.twitter.querulous.config
+
+import com.twitter.querulous._
+import com.twitter.util.Duration
+import evaluator._
+
+trait AutoDisablingQueryEvaluator {
+ def errorCount: Int
+ def interval: Duration
+}
+
+class QueryEvaluator {
+ var database: Database = new Database
+ var query: Query = new Query
+
+ var autoDisable: Option[AutoDisablingQueryEvaluator] = None
+ def autoDisable_=(a: AutoDisablingQueryEvaluator) { autoDisable = Some(a) }
+
+
+ def apply(stats: StatsCollector): QueryEvaluatorFactory = {
+ var factory: QueryEvaluatorFactory =
+ new StandardQueryEvaluatorFactory(database(stats), query(stats))
+
+ autoDisable.foreach { disable =>
+ factory = new AutoDisablingQueryEvaluatorFactory(
+ factory, disable.errorCount, disable.interval
+ )
+ }
+ factory
+ }
+
+ def apply(): QueryEvaluatorFactory = apply(NullStatsCollector)
+}
diff --git a/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala b/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala
index d2c590f..5ab6213 100644
--- a/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/database/ApachePoolingDatabase.scala
@@ -3,30 +3,42 @@ package com.twitter.querulous.database
import java.sql.{SQLException, Connection}
import org.apache.commons.dbcp.{PoolableConnectionFactory, DriverManagerConnectionFactory, PoolingDataSource}
import org.apache.commons.pool.impl.{GenericObjectPool, StackKeyedObjectPoolFactory}
-import com.twitter.xrayspecs.Duration
+import com.twitter.util.Duration
class ApachePoolingDatabaseFactory(
- minOpenConnections: Int,
- maxOpenConnections: Int,
+ val minOpenConnections: Int,
+ val maxOpenConnections: Int,
checkConnectionHealthWhenIdleFor: Duration,
maxWaitForConnectionReservation: Duration,
checkConnectionHealthOnReservation: Boolean,
- evictConnectionIfIdleFor: Duration) extends DatabaseFactory {
+ evictConnectionIfIdleFor: Duration,
+ defaultUrlOptions: Map[String, String]) extends DatabaseFactory {
+
+ def this(minConns: Int, maxConns: Int, checkIdle: Duration, maxWait: Duration, checkHealth: Boolean, evictTime: Duration) = {
+ this(minConns, maxConns, checkIdle, maxWait, checkHealth, evictTime, Map.empty)
+ }
def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = {
- val pool = new ApachePoolingDatabase(
+ val finalUrlOptions =
+ if (urlOptions eq null) {
+ defaultUrlOptions
+ } else {
+ defaultUrlOptions ++ urlOptions
+ }
+
+ new ApachePoolingDatabase(
dbhosts,
dbname,
username,
password,
- urlOptions,
+ finalUrlOptions,
minOpenConnections,
maxOpenConnections,
checkConnectionHealthWhenIdleFor,
maxWaitForConnectionReservation,
checkConnectionHealthOnReservation,
- evictConnectionIfIdleFor)
- pool
+ evictConnectionIfIdleFor
+ )
}
}
@@ -80,5 +92,5 @@ class ApachePoolingDatabase(
def open() = poolingDataSource.getConnection()
- override def toString = dbhosts.first + "_" + dbname
+ override def toString = dbhosts.head + "_" + dbname
}
diff --git a/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala b/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala
index ed1c29f..1f14388 100644
--- a/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/database/AutoDisablingDatabase.scala
@@ -1,15 +1,16 @@
package com.twitter.querulous.database
-import com.twitter.xrayspecs.Duration
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.querulous.AutoDisabler
+import com.twitter.util.Duration
+import com.twitter.util.TimeConversions._
import java.sql.{Connection, SQLException, SQLIntegrityConstraintViolationException}
-class AutoDisablingDatabaseFactory(databaseFactory: DatabaseFactory, disableErrorCount: Int, disableDuration: Duration) extends DatabaseFactory {
+class AutoDisablingDatabaseFactory(val databaseFactory: DatabaseFactory, val disableErrorCount: Int, val disableDuration: Duration) extends DatabaseFactory {
def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = {
new AutoDisablingDatabase(
databaseFactory(dbhosts, dbname, username, password, urlOptions),
- dbhosts.first,
+ dbhosts.head,
disableErrorCount,
disableDuration)
}
diff --git a/src/main/scala/com/twitter/querulous/database/Database.scala b/src/main/scala/com/twitter/querulous/database/Database.scala
index d6d0c78..abcdbb8 100644
--- a/src/main/scala/com/twitter/querulous/database/Database.scala
+++ b/src/main/scala/com/twitter/querulous/database/Database.scala
@@ -1,46 +1,18 @@
package com.twitter.querulous.database
+import com.twitter.querulous._
import java.sql.Connection
-import com.twitter.xrayspecs.TimeConversions._
-import net.lag.configgy.ConfigMap
-
-
-object DatabaseFactory {
- def fromConfig(config: ConfigMap, statsCollector: Option[StatsCollector]) = {
- var factory: DatabaseFactory = if (config.contains("size_min")) {
- new ApachePoolingDatabaseFactory(
- config("size_min").toInt,
- config("size_max").toInt,
- config("test_idle_msec").toLong.millis,
- config("max_wait").toLong.millis,
- config("test_on_borrow").toBoolean,
- config("min_evictable_idle_msec").toLong.millis)
- } else {
- new SingleConnectionDatabaseFactory()
- }
- statsCollector.foreach { stats =>
- factory = new StatsCollectingDatabaseFactory(factory, stats)
- }
- config.getConfigMap("timeout").foreach { timeoutConfig =>
- factory = new TimingOutDatabaseFactory(factory,
- timeoutConfig("pool_size").toInt,
- timeoutConfig("queue_size").toInt,
- timeoutConfig("open").toLong.millis,
- timeoutConfig("initialize").toLong.millis,
- config("size_max").toInt)
- }
- new MemoizingDatabaseFactory(factory)
- }
-}
+import com.twitter.querulous.StatsCollector
+import com.twitter.util.TimeConversions._
trait DatabaseFactory {
def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]): Database
def apply(dbhosts: List[String], dbname: String, username: String, password: String): Database =
- apply(dbhosts, dbname, username, password, null)
+ apply(dbhosts, dbname, username, password, Map.empty)
def apply(dbhosts: List[String], username: String, password: String): Database =
- apply(dbhosts, null, username, password, null)
+ apply(dbhosts, null, username, password, Map.empty)
}
trait Database {
@@ -57,14 +29,18 @@ trait Database {
}
}
+ val defaultUrlOptions = Map(
+ "useUnicode" -> "true",
+ "characterEncoding" -> "UTF-8",
+ "connectTimeout" -> "500"
+ )
+
protected def url(dbhosts: List[String], dbname: String, urlOptions: Map[String, String]) = {
val dbnameSegment = if (dbname == null) "" else ("/" + dbname)
- val urlOptsSegment = if (urlOptions == null) {
- "?useUnicode=true&characterEncoding=UTF-8"
- } else {
- "?" + urlOptions.keys.map( k => k + "=" + urlOptions(k) ).mkString("&")
- }
- "jdbc:mysql://" + dbhosts.mkString(",") + dbnameSegment + urlOptsSegment
+ val finalUrlOpts = defaultUrlOptions ++ urlOptions
+ val urlOptsSegment = finalUrlOpts.map(Function.tupled((k, v) => k+"="+v )).mkString("&")
+
+ "jdbc:mysql://" + dbhosts.mkString(",") + dbnameSegment + "?" + urlOptsSegment
}
}
diff --git a/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala b/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala
index ec0ffd7..e13bf5d 100644
--- a/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/database/MemoizingDatabase.scala
@@ -2,12 +2,12 @@ package com.twitter.querulous.database
import scala.collection.mutable
-class MemoizingDatabaseFactory(databaseFactory: DatabaseFactory) extends DatabaseFactory {
+class MemoizingDatabaseFactory(val databaseFactory: DatabaseFactory) extends DatabaseFactory {
private val databases = new mutable.HashMap[String, Database] with mutable.SynchronizedMap[String, Database]
def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = synchronized {
databases.getOrElseUpdate(
- dbhosts.first + "/" + dbname,
+ dbhosts.head + "/" + dbname,
databaseFactory(dbhosts, dbname, username, password, urlOptions))
}
diff --git a/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala b/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala
index d582963..44359da 100644
--- a/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/database/SingleConnectionDatabase.scala
@@ -4,9 +4,18 @@ import org.apache.commons.dbcp.DriverManagerConnectionFactory
import java.sql.{SQLException, Connection}
-class SingleConnectionDatabaseFactory extends DatabaseFactory {
+class SingleConnectionDatabaseFactory(defaultUrlOptions: Map[String, String]) extends DatabaseFactory {
+ def this() = this(Map.empty)
+
def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = {
- new SingleConnectionDatabase(dbhosts, dbname, username, password, urlOptions)
+ val finalUrlOptions =
+ if (urlOptions eq null) {
+ defaultUrlOptions
+ } else {
+ defaultUrlOptions ++ urlOptions
+ }
+
+ new SingleConnectionDatabase(dbhosts, dbname, username, password, finalUrlOptions)
}
}
@@ -24,5 +33,5 @@ class SingleConnectionDatabase(dbhosts: List[String], dbname: String, username:
}
def open() = connectionFactory.createConnection()
- override def toString = dbhosts.first + "_" + dbname
+ override def toString = dbhosts.head + "_" + dbname
}
diff --git a/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala b/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala
index 8d77d31..d8af7ef 100644
--- a/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/database/StatsCollectingDatabase.scala
@@ -1,5 +1,6 @@
package com.twitter.querulous.database
+import com.twitter.querulous.StatsCollector
import java.sql.Connection
class StatsCollectingDatabaseFactory(
@@ -15,14 +16,26 @@ class StatsCollectingDatabase(database: Database, stats: StatsCollector)
extends Database {
override def open(): Connection = {
- stats.time("database-open-timing") {
- database.open()
+ stats.time("db-open-timing") {
+ try {
+ database.open()
+ } catch {
+ case e: SqlDatabaseTimeoutException =>
+ stats.incr("db-open-timeout-count", 1)
+ throw e
+ }
}
}
override def close(connection: Connection) = {
- stats.time("database-close-timing") {
- database.close(connection)
+ stats.time("db-close-timing") {
+ try {
+ database.close(connection)
+ } catch {
+ case e: SqlDatabaseTimeoutException =>
+ stats.incr("db-close-timeout-count", 1)
+ throw e
+ }
}
}
}
diff --git a/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala b/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala
index d98b582..a4b9751 100644
--- a/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/database/TimingOutDatabase.scala
@@ -1,28 +1,35 @@
package com.twitter.querulous.database
+import com.twitter.querulous.{FutureTimeout, TimeoutException}
import java.sql.{Connection, SQLException}
import java.util.concurrent.{TimeoutException => JTimeoutException, _}
-import com.twitter.xrayspecs.Duration
-import net.lag.logging.Logger
+import com.twitter.util.Duration
class SqlDatabaseTimeoutException(msg: String, val timeout: Duration) extends SQLException(msg)
-class TimingOutDatabaseFactory(databaseFactory: DatabaseFactory, poolSize: Int, queueSize: Int, openTimeout: Duration, initialTimeout: Duration, maxConnections: Int) extends DatabaseFactory {
- def apply(dbhosts: List[String], dbname: String, username: String, password: String, urlOptions: Map[String, String]) = {
+class TimingOutDatabaseFactory(
+ val databaseFactory: DatabaseFactory,
+ val poolSize: Int,
+ val queueSize: Int,
+ val openTimeout: Duration,
+ val maxConnections: Int)
+extends DatabaseFactory {
+
+ private def newTimeoutPool() = new FutureTimeout(poolSize, queueSize)
+
+ def apply(dbhosts: List[String], dbname: String, username: String, password: String,
+ urlOptions: Map[String, String]) = {
val dbLabel = if (dbname != null) dbname else "(null)"
- new TimingOutDatabase(databaseFactory(dbhosts, dbname, username, password, urlOptions), dbhosts, dbLabel, poolSize, queueSize, openTimeout, initialTimeout, maxConnections)
+ new TimingOutDatabase(databaseFactory(dbhosts, dbname, username, password, urlOptions),
+ dbhosts, dbLabel, newTimeoutPool(), openTimeout, maxConnections)
}
}
-class TimingOutDatabase(database: Database, dbhosts: List[String], dbname: String, poolSize: Int, queueSize: Int, openTimeout: Duration, initialTimeout: Duration, maxConnections: Int) extends Database {
- private val timeout = new FutureTimeout(poolSize, queueSize)
- private val log = Logger.get(getClass.getName)
-
- // FIXME not working yet.
- //greedilyInstantiateConnections()
-
+class TimingOutDatabase(database: Database, dbhosts: List[String], dbname: String,
+ timeout: FutureTimeout, openTimeout: Duration,
+ maxConnections: Int) extends Database {
private def getConnection(wait: Duration) = {
try {
timeout(wait) {
@@ -36,13 +43,6 @@ class TimingOutDatabase(database: Database, dbhosts: List[String], dbname: Strin
}
}
- private def greedilyInstantiateConnections() = {
- log.info("Connecting to %s:%s", dbhosts.mkString(","), dbname)
- (0 until maxConnections).force.map { i =>
- getConnection(initialTimeout)
- }.map(_.close)
- }
-
override def open() = getConnection(openTimeout)
def close(connection: Connection) { database.close(connection) }
diff --git a/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala b/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala
index f4f6170..0230b20 100644
--- a/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala
+++ b/src/main/scala/com/twitter/querulous/evaluator/AutoDisablingQueryEvaluator.scala
@@ -3,8 +3,10 @@ package com.twitter.querulous.evaluator
import java.sql.ResultSet
import java.sql.{SQLException, SQLIntegrityConstraintViolationException}
import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException
-import com.twitter.xrayspecs.{Time, Duration}
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.querulous.AutoDisabler
+import com.twitter.util.{Time, Duration}
+import com.twitter.util.TimeConversions._
+
class AutoDisablingQueryEvaluatorFactory(
queryEvaluatorFactory: QueryEvaluatorFactory,
@@ -28,9 +30,6 @@ class AutoDisablingQueryEvaluator (
protected val disableErrorCount: Int,
protected val disableDuration: Duration) extends QueryEvaluatorProxy(queryEvaluator) with AutoDisabler {
- private var disabledUntil: Time = Time.never
- private var consecutiveErrors = 0
-
override protected def delegate[A](f: => A) = {
throwIfDisabled()
try {
diff --git a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala
index af68ed4..1de3531 100644
--- a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala
+++ b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluator.scala
@@ -1,29 +1,12 @@
package com.twitter.querulous.evaluator
+import com.twitter.querulous._
import java.sql.ResultSet
-import com.twitter.xrayspecs.TimeConversions._
-import net.lag.configgy.ConfigMap
-import database._
-import query._
-
-
-object QueryEvaluatorFactory {
- def fromConfig(config: ConfigMap, databaseFactory: DatabaseFactory, queryFactory: QueryFactory): QueryEvaluatorFactory = {
- var factory: QueryEvaluatorFactory = new StandardQueryEvaluatorFactory(databaseFactory, queryFactory)
- config.getConfigMap("disable").foreach { disableConfig =>
- factory = new AutoDisablingQueryEvaluatorFactory(factory,
- disableConfig("error_count").toInt,
- disableConfig("seconds").toInt.seconds)
- }
- factory
- }
+import com.twitter.util.TimeConversions._
+import com.twitter.querulous.StatsCollector
+import com.twitter.querulous.database._
+import com.twitter.querulous.query._
- def fromConfig(config: ConfigMap, statsCollector: Option[StatsCollector]): QueryEvaluatorFactory = {
- fromConfig(config,
- DatabaseFactory.fromConfig(config.configMap("connection_pool"), statsCollector),
- QueryFactory.fromConfig(config, statsCollector))
- }
-}
object QueryEvaluator extends QueryEvaluatorFactory {
private def createEvaluatorFactory() = {
@@ -45,39 +28,56 @@ trait QueryEvaluatorFactory {
}
def apply(dbhosts: List[String], dbname: String, username: String, password: String): QueryEvaluator = {
- apply(dbhosts, dbname, username, password, null)
+ apply(dbhosts, dbname, username, password, Map[String,String]())
}
def apply(dbhost: String, dbname: String, username: String, password: String): QueryEvaluator = {
- apply(List(dbhost), dbname, username, password, null)
+ apply(List(dbhost), dbname, username, password, Map[String,String]())
}
def apply(dbhost: String, username: String, password: String): QueryEvaluator = {
- apply(List(dbhost), null, username, password, null)
+ apply(List(dbhost), null, username, password, Map[String,String]())
}
def apply(dbhosts: List[String], username: String, password: String): QueryEvaluator = {
- apply(dbhosts, null, username, password, null)
+ apply(dbhosts, null, username, password, Map[String,String]())
}
- def apply(config: ConfigMap): QueryEvaluator = {
- apply(
- config.getList("hostname").toList,
- config.getString("database").getOrElse(null),
- config("username"),
- config.getString("password").getOrElse(null),
- // this is so lame, why do I have to cast this back?
- config.getConfigMap("url_options").map(_.asMap.asInstanceOf[Map[String, String]]).getOrElse(null)
- )
+ def apply(connection: config.Connection): QueryEvaluator = {
+ apply(connection.hostnames.toList, connection.database, connection.username, connection.password, connection.urlOptions)
}
}
+class ParamsApplier(query: Query) {
+ def apply(params: Any*) = query.addParams(params)
+}
+
trait QueryEvaluator {
- def select[A](query: String, params: Any*)(f: ResultSet => A): Seq[A]
- def selectOne[A](query: String, params: Any*)(f: ResultSet => A): Option[A]
- def count(query: String, params: Any*): Int
- def execute(query: String, params: Any*): Int
+ def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A): Seq[A]
+ def select[A](query: String, params: Any*)(f: ResultSet => A): Seq[A] =
+ select(QueryClass.Select, query, params: _*)(f)
+
+ def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A): Option[A]
+ def selectOne[A](query: String, params: Any*)(f: ResultSet => A): Option[A] =
+ selectOne(QueryClass.Select, query, params: _*)(f)
+
+ def count(queryClass: QueryClass, query: String, params: Any*): Int
+ def count(query: String, params: Any*): Int =
+ count(QueryClass.Select, query, params: _*)
+
+ def execute(queryClass: QueryClass, query: String, params: Any*): Int
+ def execute(query: String, params: Any*): Int =
+ execute(QueryClass.Execute, query, params: _*)
+
+ def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit): Int
+ def executeBatch(query: String)(f: ParamsApplier => Unit): Int =
+ executeBatch(QueryClass.Execute, query)(f)
+
def nextId(tableName: String): Long
- def insert(query: String, params: Any*): Long
+
+ def insert(queryClass: QueryClass, query: String, params: Any*): Long
+ def insert(query: String, params: Any*): Long =
+ insert(QueryClass.Execute, query, params: _*)
+
def transaction[T](f: Transaction => T): T
}
diff --git a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala
index b9fec3b..77177bc 100644
--- a/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala
+++ b/src/main/scala/com/twitter/querulous/evaluator/QueryEvaluatorProxy.scala
@@ -3,31 +3,37 @@ package com.twitter.querulous.evaluator
import java.sql.ResultSet
import java.sql.{SQLException, SQLIntegrityConstraintViolationException}
import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException
-import com.twitter.xrayspecs.{Time, Duration}
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.{Time, Duration}
+import com.twitter.util.TimeConversions._
+import com.twitter.querulous.query.{QueryClass, Query}
+
abstract class QueryEvaluatorProxy(queryEvaluator: QueryEvaluator) extends QueryEvaluator {
- def select[A](query: String, params: Any*)(f: ResultSet => A) = {
- delegate(queryEvaluator.select(query, params: _*)(f))
+ def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = {
+ delegate(queryEvaluator.select(queryClass, query, params: _*)(f))
+ }
+
+ def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = {
+ delegate(queryEvaluator.selectOne(queryClass, query, params: _*)(f))
}
- def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = {
- delegate(queryEvaluator.selectOne(query, params: _*)(f))
+ def execute(queryClass: QueryClass, query: String, params: Any*) = {
+ delegate(queryEvaluator.execute(queryClass, query, params: _*))
}
- def execute(query: String, params: Any*) = {
- delegate(queryEvaluator.execute(query, params: _*))
+ def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit) = {
+ delegate(queryEvaluator.executeBatch(queryClass, query)(f))
}
- def count(query: String, params: Any*) = {
- delegate(queryEvaluator.count(query, params: _*))
+ def count(queryClass: QueryClass, query: String, params: Any*) = {
+ delegate(queryEvaluator.count(queryClass, query, params: _*))
}
def nextId(tableName: String) = {
delegate(queryEvaluator.nextId(tableName))
}
- def insert(query: String, params: Any*) = {
+ def insert(queryClass: QueryClass, query: String, params: Any*) = {
delegate(queryEvaluator.insert(query, params: _*))
}
diff --git a/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala b/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala
index f0c7a7a..78e1e6b 100644
--- a/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala
+++ b/src/main/scala/com/twitter/querulous/evaluator/StandardQueryEvaluator.scala
@@ -4,7 +4,7 @@ import java.sql.ResultSet
import org.apache.commons.dbcp.{DriverManagerConnectionFactory, PoolableConnectionFactory, PoolingDataSource}
import org.apache.commons.pool.impl.GenericObjectPool
import com.twitter.querulous.database.{Database, DatabaseFactory}
-import com.twitter.querulous.query.QueryFactory
+import com.twitter.querulous.query.{Query, QueryClass, QueryFactory}
class StandardQueryEvaluatorFactory(
databaseFactory: DatabaseFactory,
@@ -19,12 +19,31 @@ class StandardQueryEvaluatorFactory(
class StandardQueryEvaluator(protected val database: Database, queryFactory: QueryFactory)
extends QueryEvaluator {
- def select[A](query: String, params: Any*)(f: ResultSet => A) = withTransaction(_.select(query, params: _*)(f))
- def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = withTransaction(_.selectOne(query, params: _*)(f))
- def count(query: String, params: Any*) = withTransaction(_.count(query, params: _*))
- def execute(query: String, params: Any*) = withTransaction(_.execute(query, params: _*))
+ def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = {
+ withTransaction(_.select(queryClass, query, params: _*)(f))
+ }
+
+ def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = {
+ withTransaction(_.selectOne(queryClass, query, params: _*)(f))
+ }
+
+ def count(queryClass: QueryClass, query: String, params: Any*) = {
+ withTransaction(_.count(queryClass, query, params: _*))
+ }
+
+ def execute(queryClass: QueryClass, query: String, params: Any*) = {
+ withTransaction(_.execute(queryClass, query, params: _*))
+ }
+
+ def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit) = {
+ withTransaction(_.executeBatch(queryClass, query)(f))
+ }
+
def nextId(tableName: String) = withTransaction(_.nextId(tableName))
- def insert(query: String, params: Any*) = withTransaction(_.insert(query, params: _*))
+
+ def insert(queryClass: QueryClass, query: String, params: Any*) = {
+ withTransaction(_.insert(queryClass, query, params: _*))
+ }
def transaction[T](f: Transaction => T) = {
withTransaction { transaction =>
@@ -35,7 +54,9 @@ class StandardQueryEvaluator(protected val database: Database, queryFactory: Que
rv
} catch {
case e: Throwable =>
- transaction.rollback()
+ try {
+ transaction.rollback()
+ } catch { case _ => () }
throw e
}
}
diff --git a/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala b/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala
index e438412..6f0fe3c 100644
--- a/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala
+++ b/src/main/scala/com/twitter/querulous/evaluator/Transaction.scala
@@ -1,24 +1,29 @@
package com.twitter.querulous.evaluator
import java.sql.{ResultSet, SQLException, SQLIntegrityConstraintViolationException, Connection}
-import com.twitter.querulous.query.QueryFactory
+import com.twitter.querulous.query.{QueryClass, QueryFactory, Query}
class Transaction(queryFactory: QueryFactory, connection: Connection) extends QueryEvaluator {
- def select[A](query: String, params: Any*)(f: ResultSet => A) = {
- queryFactory(connection, query, params: _*).select(f)
+ def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = {
+ queryFactory(connection, queryClass, query, params: _*).select(f)
}
- def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = {
- val results = select(query, params: _*)(f)
- if (results.isEmpty) None else Some(results.first)
+ def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = {
+ select(queryClass, query, params: _*)(f).headOption
}
- def count(query: String, params: Any*) = {
- selectOne(query, params: _*)(_.getInt("count(*)")) getOrElse 0
+ def count(queryClass: QueryClass, query: String, params: Any*) = {
+ selectOne(queryClass, query, params: _*)(_.getInt("count(*)")) getOrElse 0
}
- def execute(query: String, params: Any*) = {
- queryFactory(connection, query, params: _*).execute()
+ def execute(queryClass: QueryClass, query: String, params: Any*) = {
+ queryFactory(connection, queryClass, query, params: _*).execute()
+ }
+
+ def executeBatch(queryClass: QueryClass, queryString: String)(f: ParamsApplier => Unit) = {
+ val query = queryFactory(connection, queryClass, queryString)
+ f(new ParamsApplier(query))
+ query.execute
}
def nextId(tableName: String) = {
@@ -26,8 +31,8 @@ class Transaction(queryFactory: QueryFactory, connection: Connection) extends Qu
selectOne("SELECT LAST_INSERT_ID()") { _.getLong("LAST_INSERT_ID()") } getOrElse 0L
}
- def insert(query: String, params: Any*): Long = {
- execute(query, params: _*)
+ def insert(queryClass: QueryClass, query: String, params: Any*): Long = {
+ execute(queryClass, query, params: _*)
selectOne("SELECT LAST_INSERT_ID()") { _.getLong("LAST_INSERT_ID()") } getOrElse {
throw new SQLIntegrityConstraintViolationException
}
diff --git a/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala b/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala
index 0c48681..cad1323 100644
--- a/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala
+++ b/src/main/scala/com/twitter/querulous/query/DebuggingQuery.scala
@@ -3,8 +3,8 @@ package com.twitter.querulous.query
import java.sql.{Timestamp, Connection}
class DebuggingQueryFactory(queryFactory: QueryFactory, log: String => Unit) extends QueryFactory {
- def apply(connection: Connection, query: String, params: Any*) = {
- new DebuggingQuery(queryFactory(connection, query, params: _*), log, query, params)
+ def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = {
+ new DebuggingQuery(queryFactory(connection, queryClass, query, params: _*), log, query, params)
}
}
diff --git a/src/main/scala/com/twitter/querulous/query/Query.scala b/src/main/scala/com/twitter/querulous/query/Query.scala
index 22b50b4..873f01a 100644
--- a/src/main/scala/com/twitter/querulous/query/Query.scala
+++ b/src/main/scala/com/twitter/querulous/query/Query.scala
@@ -1,67 +1,19 @@
package com.twitter.querulous.query
+import com.twitter.querulous._
import java.sql.{ResultSet, Connection}
import scala.collection.mutable
-import com.twitter.xrayspecs.Duration
-import com.twitter.xrayspecs.TimeConversions._
-import net.lag.configgy.ConfigMap
-import net.lag.logging.Logger
-
+import com.twitter.querulous.StatsCollector
+import com.twitter.util.Duration
+import com.twitter.util.TimeConversions._
trait QueryFactory {
- def apply(connection: Connection, queryString: String, params: Any*): Query
+ def apply(connection: Connection, queryClass: QueryClass, queryString: String, params: Any*): Query
}
trait Query {
def select[A](f: ResultSet => A): Seq[A]
def execute(): Int
+ def addParams(params: Any*)
def cancel()
}
-
-object QueryFactory {
- private def convertConfigMap(queryMap: ConfigMap) = {
- val queryInfo = new mutable.HashMap[String, (String, Duration)]
- for (key <- queryMap.keys) {
- val pair = queryMap.getList(key)
- val query = pair(0)
- val timeout = pair(1).toLong.millis
- queryInfo += (query -> (key, timeout))
- }
- queryInfo
- }
-
- /*
- query_timeout_default = 3000
- queries {
- select_source_id_for_update = ["SELECT * FROM ? WHERE source_id = ? FOR UPDATE", 3000]
- }
- retries = 3
- debug = false
- */
- def fromConfig(config: ConfigMap, statsCollector: Option[StatsCollector]): QueryFactory = {
- var queryFactory: QueryFactory = new SqlQueryFactory
- config.getConfigMap("queries") match {
- case Some(queryMap) =>
- val queryInfo = convertConfigMap(queryMap)
- val timeout = config("query_timeout_default").toLong.millis
- queryFactory = new TimingOutStatsCollectingQueryFactory(queryFactory, queryInfo, timeout,
- statsCollector.get)
- case None =>
- config.getInt("query_timeout_default").foreach { timeout =>
- queryFactory = new TimingOutQueryFactory(queryFactory, timeout.millis)
- }
- statsCollector.foreach { stats =>
- queryFactory = new StatsCollectingQueryFactory(queryFactory, stats)
- }
- }
-
- config.getInt("retries").foreach { retries =>
- queryFactory = new RetryingQueryFactory(queryFactory, retries)
- }
- if (config.getBool("debug", false)) {
- val log = Logger.get(getClass.getName)
- queryFactory = new DebuggingQueryFactory(queryFactory, { s => log.debug(s) })
- }
- queryFactory
- }
-}
\ No newline at end of file
diff --git a/src/main/scala/com/twitter/querulous/query/QueryClass.scala b/src/main/scala/com/twitter/querulous/query/QueryClass.scala
new file mode 100644
index 0000000..39618e4
--- /dev/null
+++ b/src/main/scala/com/twitter/querulous/query/QueryClass.scala
@@ -0,0 +1,19 @@
+package com.twitter.querulous.query
+
+import scala.collection.mutable
+
+class QueryClass(val name: String)
+
+object QueryClass {
+ val classes = mutable.Map[String, QueryClass]()
+
+ def apply(name: String) = {
+ classes(name) = new QueryClass(name)
+ lookup(name)
+ }
+
+ def lookup(name: String) = classes(name)
+
+ val Select = QueryClass("select")
+ val Execute = QueryClass("execute")
+}
diff --git a/src/main/scala/com/twitter/querulous/query/QueryProxy.scala b/src/main/scala/com/twitter/querulous/query/QueryProxy.scala
index 145a8a7..697daaf 100644
--- a/src/main/scala/com/twitter/querulous/query/QueryProxy.scala
+++ b/src/main/scala/com/twitter/querulous/query/QueryProxy.scala
@@ -9,5 +9,7 @@ abstract class QueryProxy(query: Query) extends Query {
def cancel() = query.cancel()
+ def addParams(params: Any*) = query.addParams(params)
+
protected def delegate[A](f: => A) = f
}
diff --git a/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala b/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala
index 6189cf7..a413ff9 100644
--- a/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala
+++ b/src/main/scala/com/twitter/querulous/query/RetryingQuery.scala
@@ -1,12 +1,12 @@
package com.twitter.querulous.query
import java.sql.{SQLException, Connection}
-import com.twitter.xrayspecs.Duration
+import com.twitter.util.Duration
class RetryingQueryFactory(queryFactory: QueryFactory, retries: Int) extends QueryFactory {
- def apply(connection: Connection, query: String, params: Any*) = {
- new RetryingQuery(queryFactory(connection, query, params: _*), retries)
+ def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = {
+ new RetryingQuery(queryFactory(connection, queryClass, query, params: _*), retries)
}
}
diff --git a/src/main/scala/com/twitter/querulous/query/SqlQuery.scala b/src/main/scala/com/twitter/querulous/query/SqlQuery.scala
index 60577ca..4e655d5 100644
--- a/src/main/scala/com/twitter/querulous/query/SqlQuery.scala
+++ b/src/main/scala/com/twitter/querulous/query/SqlQuery.scala
@@ -6,13 +6,13 @@ import java.util.regex.Pattern
import scala.collection.mutable
class SqlQueryFactory extends QueryFactory {
- def apply(connection: Connection, query: String, params: Any*) = {
+ def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = {
new SqlQuery(connection, query, params: _*)
}
}
-class TooFewQueryParametersException extends Exception
-class TooManyQueryParametersException extends Exception
+class TooFewQueryParametersException(t: Throwable) extends Exception(t)
+class TooManyQueryParametersException(t: Throwable) extends Exception(t)
sealed abstract case class NullValue(typeVal: Int)
object NullValues {
@@ -43,7 +43,13 @@ object NullValues {
class SqlQuery(connection: Connection, query: String, params: Any*) extends Query {
- val statement = buildStatement(connection, query, params: _*)
+ def this(connection: Connection, query: String) = {
+ this(connection, query, Nil)
+ }
+
+ var paramsInitialized = false
+ var statement = buildStatement(connection, query, params: _*)
+ var batchMode = false
def select[A](f: ResultSet => A): Seq[A] = {
withStatement {
@@ -61,9 +67,22 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer
}
}
+ def addParams(params: Any*) = {
+ if(paramsInitialized && !batchMode) {
+ statement.addBatch()
+ }
+ setBindVariable(statement, 1, params)
+ statement.addBatch()
+ batchMode = true
+ }
+
def execute() = {
withStatement {
- statement.executeUpdate()
+ if(batchMode) {
+ statement.executeBatch().foldLeft(0)(_+_)
+ } else {
+ statement.executeUpdate()
+ }
}
}
@@ -94,26 +113,37 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer
statement
}
- private def expandArrayParams(query: String, params: Any*) = {
+ private def expandArrayParams(query: String, params: Any*): String = {
+ if(params.isEmpty){
+ return query
+ }
val p = Pattern.compile("\\?")
val m = p.matcher(query)
val result = new StringBuffer
var i = 0
+
+ def marks(param: Any): String = param match {
+ case t2: (_,_) => "(?,?)"
+ case t3: (_,_,_) => "(?,?,?)"
+ case t4: (_,_,_,_) => "(?,?,?,?)"
+ case a: Array[Byte] => "?"
+ case s: Seq[_] => s.map(marks(_)).mkString(",")
+ case _ => "?"
+ }
+
while (m.find) {
try {
- val questionMarks = params(i) match {
- case a: Array[Byte] => "?"
- case s: Seq[_] => s.map { _ => "?" }.mkString(",")
- case _ => "?"
- }
- m.appendReplacement(result, questionMarks)
+ m.appendReplacement(result, marks(params(i)))
} catch {
- case e: ArrayIndexOutOfBoundsException => throw new TooFewQueryParametersException
- case e: NoSuchElementException => throw new TooFewQueryParametersException
+ case e: ArrayIndexOutOfBoundsException =>
+ throw new TooFewQueryParametersException(e)
+ case e: NoSuchElementException =>
+ throw new TooFewQueryParametersException(e)
}
i += 1
}
m.appendTail(result)
+ paramsInitialized = true
result.toString
}
@@ -122,6 +152,12 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer
try {
param match {
+ case (a, b) =>
+ index = setBindVariable(statement, index, List(a, b)) - 1
+ case (a, b, c) =>
+ index = setBindVariable(statement, index, List(a, b, c)) - 1
+ case (a, b, c, d) =>
+ index = setBindVariable(statement, index, List(a, b, c, d)) - 1
case s: String =>
statement.setString(index, s)
case l: Long =>
@@ -146,7 +182,8 @@ class SqlQuery(connection: Connection, query: String, params: Any*) extends Quer
}
index + 1
} catch {
- case e: SQLException => throw new TooManyQueryParametersException
+ case e: SQLException =>
+ throw new TooManyQueryParametersException(e)
}
}
}
diff --git a/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala b/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala
index bbd74ea..9f89e64 100644
--- a/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala
+++ b/src/main/scala/com/twitter/querulous/query/StatsCollectingQuery.scala
@@ -1,27 +1,30 @@
package com.twitter.querulous.query
+import com.twitter.querulous.StatsCollector
import java.sql.{ResultSet, Connection}
class StatsCollectingQueryFactory(queryFactory: QueryFactory, stats: StatsCollector)
extends QueryFactory {
- def apply(connection: Connection, query: String, params: Any*) = {
- new StatsCollectingQuery(queryFactory(connection, query, params: _*), stats)
+ def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = {
+ new StatsCollectingQuery(queryFactory(connection, queryClass, query, params: _*), queryClass, stats)
}
}
-class StatsCollectingQuery(query: Query, stats: StatsCollector) extends QueryProxy(query) {
- override def select[A](f: ResultSet => A) = {
- stats.incr("db-select-count", 1)
- delegate(query.select(f))
- }
-
- override def execute() = {
- stats.incr("db-execute-count", 1)
- delegate(query.execute())
- }
-
+class StatsCollectingQuery(query: Query, queryClass: QueryClass, stats: StatsCollector) extends QueryProxy(query) {
override def delegate[A](f: => A) = {
- stats.time("db-timing")(f)
+ stats.incr("db-" + queryClass.name + "-count", 1)
+ stats.time("db-" + queryClass.name + "-timing") {
+ stats.time("db-timing") {
+ try {
+ f
+ } catch {
+ case e: SqlQueryTimeoutException =>
+ stats.incr("db-query-timeout-count", 1)
+ stats.incr("db-query-" + queryClass.name + "-timeout-count", 1)
+ throw e
+ }
+ }
+ }
}
}
diff --git a/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala b/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala
index 02c88e8..ecb7631 100644
--- a/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala
+++ b/src/main/scala/com/twitter/querulous/query/TimingOutQuery.scala
@@ -1,8 +1,10 @@
package com.twitter.querulous.query
import java.sql.{SQLException, Connection}
-import com.twitter.xrayspecs.Duration
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Duration
+import com.twitter.util.TimeConversions._
+import scala.collection.Map
+import com.twitter.querulous.{Timeout, TimeoutException}
class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Query timeout: " + timeout.inMillis + " msec")
@@ -14,29 +16,26 @@ class SqlQueryTimeoutException(val timeout: Duration) extends SQLException("Quer
*
Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working
* and executing promptly for the JDBC driver in use.
*/
-class TimingOutQueryFactory(queryFactory: QueryFactory, timeout: Duration, cancelTimeout: Duration) extends QueryFactory {
- def this(queryFactory: QueryFactory, timeout: Duration) = this(queryFactory, timeout, 0.millis)
+class TimingOutQueryFactory(queryFactory: QueryFactory, val timeout: Duration, val cancelOnTimeout: Boolean)
+ extends QueryFactory {
+
+ def this(queryFactory: QueryFactory, timeout: Duration) = this(queryFactory, timeout, false)
- def apply(connection: Connection, query: String, params: Any*) = {
- new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeout, cancelTimeout)
+ def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = {
+ new TimingOutQuery(queryFactory(connection, queryClass, query, params: _*), connection, timeout, cancelOnTimeout)
}
}
/**
- * A {@code QueryFactory} that creates {@link Query}s that execute subject to the {@code timeouts}
- * specified for individual queries. An attempt to {@link Query#cancel} a query is made if the
- * timeout expires.
- *
- *
Note that queries timing out promptly is based upon {@link java.sql.Statement#cancel} working
- * and executing promptly for the JDBC driver in use.
+ * A `QueryFactory` that creates `Query`s that execute subject to the timeouts
+ * specified for individual query classes.
*/
-class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, timeouts: Map[String, Duration], cancelTimeout: Duration)
+class PerQueryTimingOutQueryFactory(queryFactory: QueryFactory, val timeouts: Map[QueryClass, (Duration, Boolean)])
extends QueryFactory {
- def this(queryFactory: QueryFactory, timeouts: Map[String, Duration]) = this(queryFactory, timeouts, 0.millis)
-
- def apply(connection: Connection, query: String, params: Any*) = {
- new TimingOutQuery(queryFactory(connection, query, params: _*), connection, timeouts(query), cancelTimeout)
+ def apply(connection: Connection, queryClass: QueryClass, query: String, params: Any*) = {
+ val (timeout, cancelOnTimeout) = timeouts(queryClass)
+ new TimingOutQuery(queryFactory(connection, queryClass, query, params: _*), connection, timeout, cancelOnTimeout)
}
}
@@ -51,21 +50,21 @@ private object QueryCancellation {
*
Note that the query timing out promptly is based upon {@link java.sql.Statement#cancel}
* working and executing promptly for the JDBC driver in use.
*/
-class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, cancelTimeout: Duration)
+class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, cancelOnTimeout: Boolean)
extends QueryProxy(query) with ConnectionDestroying {
+ def this(query: Query, connection: Connection, timeout: Duration) = this(query, connection, timeout, false)
+
import QueryCancellation._
override def delegate[A](f: => A) = {
try {
- Timeout(timeout) {
- f
- } {
- cancel()
+ Timeout(cancelTimer, timeout)(f) {
+ if (cancelOnTimeout) cancel()
+ destroyConnection(connection)
}
} catch {
- case e: TimeoutException =>
- throw new SqlQueryTimeoutException(timeout)
+ case e: TimeoutException => throw new SqlQueryTimeoutException(timeout)
}
}
@@ -73,14 +72,10 @@ class TimingOutQuery(query: Query, connection: Connection, timeout: Duration, ca
val cancelThread = new Thread("query cancellation") {
override def run() {
try {
- Timeout(cancelTimer, cancelTimeout) {
- // start by trying the nice way
- query.cancel()
- } {
- // if the cancel times out, destroy the underlying connection
- destroyConnection(connection)
- }
- } catch { case e: TimeoutException => }
+ // This cancel may block, as it has to connect to the database.
+ // If the default socket connection timeout has been removed, this thread will run away.
+ query.cancel()
+ } catch { case e => () }
}
}
cancelThread.start()
diff --git a/src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala b/src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala
deleted file mode 100644
index c1ec893..0000000
--- a/src/main/scala/com/twitter/querulous/query/TimingOutStatsCollectingQuery.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-package com.twitter.querulous.query
-
-import java.sql.{Connection, ResultSet}
-import scala.collection.Map
-import scala.util.matching.Regex
-import scala.collection.Map
-import com.twitter.xrayspecs.Duration
-import com.twitter.xrayspecs.TimeConversions._
-import net.lag.extensions._
-
-
-object TimingOutStatsCollectingQueryFactory {
- val TABLE_NAME = """(FROM|UPDATE|INSERT INTO|LIMIT)\s+[\w-]+""".r
- val DDL_QUERY = """^\s*((CREATE|DROP|ALTER)\s+(TABLE|DATABASE)|DESCRIBE)\s+""".r
-
- def simplifiedQuery(query: String) = {
- if (DDL_QUERY.findFirstMatchIn(query).isDefined) {
- "default"
- } else {
- query.regexSub(TABLE_NAME) { m => m.group(1) + "?" }
- }
- }
-}
-
-class TimingOutStatsCollectingQueryFactory(queryFactory: QueryFactory,
- queryInfo: Map[String, (String, Duration)],
- defaultTimeout: Duration, cancelTimeout: Duration, stats: StatsCollector)
- extends QueryFactory {
-
- def this(queryFactory: QueryFactory, queryInfo: Map[String, (String, Duration)], defaultTimeout: Duration, stats: StatsCollector) =
- this(queryFactory, queryInfo, defaultTimeout, 0.millis, stats)
-
- def apply(connection: Connection, query: String, params: Any*) = {
- val simplifiedQueryString = TimingOutStatsCollectingQueryFactory.simplifiedQuery(query)
- val (name, timeout) = queryInfo.getOrElse(simplifiedQueryString, ("default", defaultTimeout))
-
- new TimingOutStatsCollectingQuery(
- new TimingOutQuery(
- queryFactory(connection, query, params: _*),
- connection,
- timeout,
- cancelTimeout),
- name,
- stats)
- }
-}
-
-class TimingOutStatsCollectingQuery(query: Query, queryName: String, stats: StatsCollector) extends QueryProxy(query) {
- override def select[A](f: ResultSet => A) = {
- stats.incr("db-select-count", 1)
- stats.time("db-select-timing")(delegate(query.select(f)))
- }
-
- override def execute() = {
- stats.incr("db-execute-count", 1)
- stats.time("db-execute-timing")(delegate(query.execute()))
- }
-
- override def delegate[A](f: => A) = {
- stats.incr("db-query-count-" + queryName, 1)
- stats.time("db-timing") {
- stats.time("x-db-query-timing-" + queryName) {
- f
- }
- }
- }
-}
diff --git a/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala b/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala
index 8f5743e..635ad73 100644
--- a/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala
+++ b/src/main/scala/com/twitter/querulous/test/FakeDatabase.scala
@@ -1,17 +1,20 @@
package com.twitter.querulous.test
import java.sql.Connection
-import com.twitter.xrayspecs.{Duration, Time}
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.{Duration, Time}
+import com.twitter.util.TimeConversions._
import com.twitter.querulous.database.Database
-class FakeDatabase(connection: Connection, latency: Duration) extends Database {
+class FakeDatabase(connection: Connection, before: Option[String => Unit]) extends Database {
+ def this(connection: Connection) = this(connection, None)
+ def this(connection: Connection, before: String => Unit) = this(connection, Some(before))
+
def open(): Connection = {
- Time.advance(latency)
+ before.foreach { _("open") }
connection
}
def close(connection: Connection) {
- Time.advance(latency)
+ before.foreach { _("close") }
}
}
diff --git a/src/main/scala/com/twitter/querulous/test/FakeQuery.scala b/src/main/scala/com/twitter/querulous/test/FakeQuery.scala
index a243ceb..b746244 100644
--- a/src/main/scala/com/twitter/querulous/test/FakeQuery.scala
+++ b/src/main/scala/com/twitter/querulous/test/FakeQuery.scala
@@ -12,4 +12,6 @@ class FakeQuery(resultSets: Seq[ResultSet]) extends Query {
}
override def execute() = 0
+
+ def addParams(params: Any*) = {}
}
diff --git a/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala b/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala
index f938080..17d60f2 100644
--- a/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala
+++ b/src/main/scala/com/twitter/querulous/test/FakeQueryEvaluator.scala
@@ -1,15 +1,18 @@
package com.twitter.querulous.test
import java.sql.ResultSet
-import com.twitter.xrayspecs.Duration
-import com.twitter.querulous.evaluator.{Transaction, QueryEvaluator}
+import com.twitter.util.Duration
+import com.twitter.querulous.evaluator.{Transaction, QueryEvaluator, ParamsApplier}
+import com.twitter.querulous.query.{QueryClass, Query}
+
class FakeQueryEvaluator[A](trans: Transaction, resultSets: Seq[ResultSet]) extends QueryEvaluator {
- def select[A](query: String, params: Any*)(f: ResultSet => A) = resultSets.map(f)
- def selectOne[A](query: String, params: Any*)(f: ResultSet => A) = None
- def count(query: String, params: Any*) = 0
- def execute(query: String, params: Any*) = 0
+ def select[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = resultSets.map(f)
+ def selectOne[A](queryClass: QueryClass, query: String, params: Any*)(f: ResultSet => A) = None
+ def count(queryClass: QueryClass, query: String, params: Any*) = 0
+ def execute(queryClass: QueryClass, query: String, params: Any*) = 0
+ def executeBatch(queryClass: QueryClass, query: String)(f: ParamsApplier => Unit) = 0
def nextId(tableName: String) = 0
- def insert(query: String, params: Any*) = 0
+ def insert(queryClass: QueryClass, query: String, params: Any*) = 0
def transaction[T](f: Transaction => T) = f(trans)
}
diff --git a/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala b/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala
index c2777ba..255ae23 100644
--- a/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala
+++ b/src/main/scala/com/twitter/querulous/test/FakeStatsCollector.scala
@@ -1,22 +1,24 @@
package com.twitter.querulous.test
import scala.collection.mutable.Map
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
+import com.twitter.querulous.StatsCollector
+
class FakeStatsCollector extends StatsCollector {
val counts = Map[String, Int]()
val times = Map[String, Long]()
def incr(name: String, count: Int) = {
- counts + (name -> (count+counts.getOrElseUpdate(name, 0)))
+ counts += (name -> (count+counts.getOrElseUpdate(name, 0)))
}
def time[A](name: String)(f: => A): A = {
val start = Time.now
val rv = f
val end = Time.now
- times + (name -> ((end-start).inMillis + times.getOrElseUpdate(name, 0L)))
+ times += (name -> ((end-start).inMillis + times.getOrElseUpdate(name, 0L)))
rv
}
}
diff --git a/src/test/scala/com/twitter/querulous/TestEvaluator.scala b/src/test/scala/com/twitter/querulous/TestEvaluator.scala
index 068c0f9..cd44fe5 100644
--- a/src/test/scala/com/twitter/querulous/TestEvaluator.scala
+++ b/src/test/scala/com/twitter/querulous/TestEvaluator.scala
@@ -1,16 +1,58 @@
package com.twitter.querulous
-import net.lag.configgy.Configgy
import com.twitter.querulous.database.{MemoizingDatabaseFactory, SingleConnectionDatabaseFactory}
import com.twitter.querulous.query.SqlQueryFactory
-import com.twitter.querulous.evaluator.StandardQueryEvaluatorFactory
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.querulous.evaluator.{QueryEvaluator, StandardQueryEvaluatorFactory}
+import com.twitter.util.Eval
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
+import java.io.File
+import java.util.concurrent.CountDownLatch
+import org.specs.Specification
+
+import config.Connection
+
+
+trait ConfiguredSpecification extends Specification {
+ val config = try {
+ Eval[Connection](new File("config/test.scala"))
+ } catch {
+ case e =>
+ e.printStackTrace()
+ throw e
+ }
+}
object TestEvaluator {
// val testDatabaseFactory = new MemoizingDatabaseFactory()
val testDatabaseFactory = new SingleConnectionDatabaseFactory
val testQueryFactory = new SqlQueryFactory
val testEvaluatorFactory = new StandardQueryEvaluatorFactory(testDatabaseFactory, testQueryFactory)
+
+ private val userEnv = System.getenv("DB_USERNAME")
+ private val passEnv = System.getenv("DB_PASSWORD")
+
+ def getDbLock(queryEvaluator: QueryEvaluator, lockName: String) = {
+ val returnLatch = new CountDownLatch(1)
+ val releaseLatch = new CountDownLatch(1)
+
+ val thread = new Thread() {
+ override def run() {
+ queryEvaluator.select("SELECT GET_LOCK('" + lockName + "', 1) AS rv") { row =>
+ returnLatch.countDown()
+ try {
+ releaseLatch.await()
+ } catch {
+ case _ =>
+ }
+ }
+ }
+ }
+
+ thread.start()
+ returnLatch.await()
+
+ releaseLatch
+ }
}
diff --git a/src/test/scala/com/twitter/querulous/TestRunner.scala b/src/test/scala/com/twitter/querulous/TestRunner.scala
deleted file mode 100644
index 83a18cf..0000000
--- a/src/test/scala/com/twitter/querulous/TestRunner.scala
+++ /dev/null
@@ -1,14 +0,0 @@
-package com.twitter.querulous
-
-import org.specs.runner.SpecsFileRunner
-import org.specs.util.Configuration
-import net.lag.configgy.Configgy
-
-object TestRunner extends SpecsFileRunner("src/test/scala/**/*.scala", ".*",
- System.getProperty("system", ".*"), System.getProperty("example", ".*")) {
-
- System.setProperty("stage", "test")
-
- Configgy.configure(System.getProperty("basedir") + "/config/" + System.getProperty("stage", "test") + ".conf")
-}
-
diff --git a/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala b/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala
index c9f572b..b177162 100644
--- a/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala
+++ b/src/test/scala/com/twitter/querulous/integration/QuerySpec.scala
@@ -1,22 +1,29 @@
package com.twitter.querulous.integration
-import org.specs.Specification
-import net.lag.configgy.Configgy
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
+import com.twitter.querulous.ConfiguredSpecification
+import com.twitter.querulous.TestEvaluator
import com.twitter.querulous.database.ApachePoolingDatabaseFactory
import com.twitter.querulous.query._
import com.twitter.querulous.evaluator.{StandardQueryEvaluatorFactory, QueryEvaluator}
+class QuerySpec extends ConfiguredSpecification {
+// Configgy.configure("config/" + System.getProperty("stage", "test") + ".conf")
-class QuerySpec extends Specification {
- Configgy.configure("config/" + System.getProperty("stage", "test") + ".conf")
-
+// val config = Configgy.config.configMap("db")
+// val username = config("username")
+// val password = config("password")
+// val queryEvaluator = testEvaluatorFactory("localhost", "db_test", username, password)
import TestEvaluator._
- val config = Configgy.config.configMap("db")
+ val queryEvaluator = testEvaluatorFactory(config)
"Query" should {
- val queryEvaluator = testEvaluatorFactory(config)
+ doBefore {
+ queryEvaluator.execute("CREATE TABLE IF NOT EXISTS foo(bar INT, baz INT)")
+ queryEvaluator.execute("TRUNCATE foo")
+ queryEvaluator.execute("INSERT INTO foo VALUES (1,1), (3,3)")
+ }
"with too many arguments" >> {
queryEvaluator.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", 1, 2, 3) { r => 1 } must throwA[TooManyQueryParametersException]
@@ -26,15 +33,22 @@ class QuerySpec extends Specification {
queryEvaluator.select("SELECT 1 FROM DUAL WHERE 1 = ? OR 1 = ?", 1) { r => 1 } must throwA[TooFewQueryParametersException]
}
+ "in batch mode" >> {
+ queryEvaluator.executeBatch("UPDATE foo SET bar = ? WHERE bar = ?") { withParams =>
+ withParams("2", "1")
+ withParams("3", "3")
+ } mustEqual 2
+ }
+
"with just the right number of arguments" >> {
queryEvaluator.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", List(1, 2, 3))(_.getInt(1)).toList mustEqual List(1)
}
"be backwards compatible" >> {
- val noOpts = testEvaluatorFactory("localhost", null, config("username"), config("password"))
+ val noOpts = testEvaluatorFactory(config.hostnames.toList, null, config.username, config.password)
noOpts.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", List(1, 2, 3))(_.getInt(1)).toList mustEqual List(1)
- val noDBNameOrOpts = testEvaluatorFactory("localhost", config("username"), config("password"))
+ val noDBNameOrOpts = testEvaluatorFactory(config.hostnames.toList, config.username, config.password)
noDBNameOrOpts.select("SELECT 1 FROM DUAL WHERE 1 IN (?)", List(1, 2, 3))(_.getInt(1)).toList mustEqual List(1)
}
}
diff --git a/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala b/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala
index d177f7d..dbe92bb 100644
--- a/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala
+++ b/src/test/scala/com/twitter/querulous/integration/TimeoutSpec.scala
@@ -1,51 +1,43 @@
package com.twitter.querulous.integration
import java.util.concurrent.CountDownLatch
-import org.specs.Specification
-import net.lag.configgy.Configgy
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
+import com.twitter.querulous.TestEvaluator
import com.twitter.querulous.database.ApachePoolingDatabaseFactory
import com.twitter.querulous.query.{SqlQueryFactory, TimingOutQueryFactory, SqlQueryTimeoutException}
import com.twitter.querulous.evaluator.{StandardQueryEvaluatorFactory, QueryEvaluator}
+import com.twitter.querulous.ConfiguredSpecification
-class TimeoutSpec extends Specification {
- Configgy.configure("config/test.conf")
-
+class TimeoutSpec extends ConfiguredSpecification {
import TestEvaluator._
- val config = Configgy.config.configMap("db")
- val username = config("username")
- val password = config("password")
val timeout = 1.second
- val timingOutQueryFactory = new TimingOutQueryFactory(testQueryFactory, timeout)
+ val timingOutQueryFactory = new TimingOutQueryFactory(testQueryFactory, timeout, false)
+ val apacheDatabaseFactory = new ApachePoolingDatabaseFactory(10, 10, 1.second, 10.millis, false, 0.seconds)
val timingOutQueryEvaluatorFactory = new StandardQueryEvaluatorFactory(testDatabaseFactory, timingOutQueryFactory)
"Timeouts" should {
doBefore {
- testEvaluatorFactory("localhost", null, username, password).execute("CREATE DATABASE IF NOT EXISTS db_test")
+ testEvaluatorFactory(config.withoutDatabase).execute("CREATE DATABASE IF NOT EXISTS db_test")
}
"honor timeouts" in {
- val queryEvaluator1 = testEvaluatorFactory(List("localhost"), "db_test", username, password)
- val latch = new CountDownLatch(1)
+ val queryEvaluator1 = testEvaluatorFactory(config)
+ val dbLock = getDbLock(queryEvaluator1, "padlock")
+
val thread = new Thread() {
override def run() {
- queryEvaluator1.select("SELECT GET_LOCK('padlock', 1) AS rv") { row =>
- latch.countDown()
- try {
- Thread.sleep(60.seconds.inMillis)
- } catch {
- case _ =>
- }
- }
+ try {
+ Thread.sleep(60.seconds.inMillis)
+ } catch { case _ => () }
+ dbLock.countDown()
}
}
thread.start()
- latch.await()
- val queryEvaluator2 = timingOutQueryEvaluatorFactory(List("localhost"), "db_test", username, password)
+ val queryEvaluator2 = timingOutQueryEvaluatorFactory(config)
val start = Time.now
queryEvaluator2.select("SELECT GET_LOCK('padlock', 60) AS rv") { row => row.getInt("rv") } must throwA[SqlQueryTimeoutException]
val end = Time.now
diff --git a/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala b/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala
index 119dd5a..4430596 100644
--- a/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/AutoDisablingQueryEvaluatorSpec.scala
@@ -6,8 +6,8 @@ import java.sql.{ResultSet, SQLException, SQLIntegrityConstraintViolationExcepti
import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException
import com.twitter.querulous.test.FakeQueryEvaluator
import com.twitter.querulous.evaluator.{AutoDisablingQueryEvaluator, Transaction}
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
class AutoDisablingQueryEvaluatorSpec extends Specification with JMocker with ClassMocker {
@@ -18,8 +18,6 @@ class AutoDisablingQueryEvaluatorSpec extends Specification with JMocker with Cl
val disableDuration = 1.minute
val queryEvaluator = new FakeQueryEvaluator(trans, List(mock[ResultSet]))
val autoDisablingQueryEvaluator = new AutoDisablingQueryEvaluator(queryEvaluator, disableErrorCount, disableDuration)
- Time.freeze()
-
"when there are no failures" >> {
autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { _ => 1 } mustEqual List(1)
}
@@ -60,21 +58,23 @@ class AutoDisablingQueryEvaluatorSpec extends Specification with JMocker with Cl
}
"when there are more than disableErrorCount failures but disableDuration has elapsed" >> {
- var invocationCount = 0
+ Time.withCurrentTimeFrozen { time =>
+ var invocationCount = 0
- (0 until disableErrorCount + 1) foreach { i =>
+ (0 until disableErrorCount + 1) foreach { i =>
+ autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { resultSet =>
+ invocationCount += 1
+ throw new SQLException
+ } must throwA[SQLException]
+ }
+ invocationCount mustEqual disableErrorCount
+
+ time.advance(1.minute)
autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { resultSet =>
invocationCount += 1
- throw new SQLException
- } must throwA[SQLException]
- }
- invocationCount mustEqual disableErrorCount
-
- Time.advance(1.minute)
- autoDisablingQueryEvaluator.select("SELECT 1 FROM DUAL") { resultSet =>
- invocationCount += 1
+ }
+ invocationCount mustEqual disableErrorCount + 1
}
- invocationCount mustEqual disableErrorCount + 1
}
}
}
diff --git a/src/test/scala/com/twitter/querulous/unit/DatabaseSpec.scala b/src/test/scala/com/twitter/querulous/unit/DatabaseSpec.scala
new file mode 100644
index 0000000..8293d69
--- /dev/null
+++ b/src/test/scala/com/twitter/querulous/unit/DatabaseSpec.scala
@@ -0,0 +1,71 @@
+package com.twitter.querulous.unit
+
+import java.sql.{PreparedStatement, Connection, Types}
+import org.apache.commons.dbcp.{DelegatingConnection => DBCPConnection}
+import com.mysql.jdbc.{ConnectionImpl => MySQLConnection}
+import java.util.Properties
+import org.specs.mock.{ClassMocker, JMocker}
+import com.twitter.util.TimeConversions._
+import com.twitter.querulous.database._
+import com.twitter.querulous.ConfiguredSpecification
+
+
+class DatabaseSpec extends ConfiguredSpecification with JMocker with ClassMocker {
+ val defaultProps = Map("socketTimeout" -> "41", "connectTimeout" -> "42")
+
+ def mysqlConn(conn: Connection) = conn match {
+ case c: DBCPConnection =>
+ c.getInnermostDelegate.asInstanceOf[MySQLConnection]
+ case c: MySQLConnection => c
+ }
+
+ def testFactory(factory: DatabaseFactory) {
+ "allow specification of default query options" in {
+ val db = factory(config.hostnames.toList, null, config.username, config.password)
+ val props = mysqlConn(db.open).getProperties
+
+ props.getProperty("connectTimeout") mustEqual "42"
+ props.getProperty("socketTimeout") mustEqual "41"
+ }
+
+ "allow override of default query options" in {
+ val db = factory(
+ config.hostnames.toList,
+ null,
+ config.username,
+ config.password,
+ Map("connectTimeout" -> "43"))
+ val props = mysqlConn(db.open).getProperties
+
+ props.getProperty("connectTimeout") mustEqual "43"
+ props.getProperty("socketTimeout") mustEqual "41"
+ }
+ }
+
+ "SingleConnectionDatabaseFactory" should {
+ val factory = new SingleConnectionDatabaseFactory(defaultProps)
+ testFactory(factory)
+ }
+
+ "ApachePoolingDatabaseFactory" should {
+ val factory = new ApachePoolingDatabaseFactory(
+ 10, 10, 1.second, 10.millis, false, 0.seconds, defaultProps
+ )
+
+ testFactory(factory)
+ }
+
+ "Database#url" should {
+ val fake = new Object with Database {
+ def open() = null
+ def close(connection: Connection) = ()
+ }.asInstanceOf[{def url(a: List[String], b:String, c:Map[String, String]): String}]
+
+ "add default unicode urlOptions" in {
+ val url = fake.url(List("host"), "db", Map())
+
+ url mustMatch "useUnicode=true"
+ url mustMatch "characterEncoding=UTF-8"
+ }
+ }
+}
diff --git a/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala b/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala
index f5a61aa..701ba80 100644
--- a/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/MemoizingDatabaseFactorySpec.scala
@@ -18,8 +18,8 @@ class MemoizingDatabaseFactorySpec extends Specification with JMocker {
val memoizingDatabase = new MemoizingDatabaseFactory(databaseFactory)
expect {
- one(databaseFactory).apply(hosts, "bar", username, password, null) willReturn database1
- one(databaseFactory).apply(hosts, "baz", username, password, null) willReturn database2
+ one(databaseFactory).apply(hosts, "bar", username, password, Map.empty) willReturn database1
+ one(databaseFactory).apply(hosts, "baz", username, password, Map.empty) willReturn database2
}
memoizingDatabase(hosts, "bar", username, password) mustBe database1
memoizingDatabase(hosts, "bar", username, password) mustBe database1
diff --git a/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala b/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala
index 9497d53..6099a77 100644
--- a/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/QueryEvaluatorSpec.scala
@@ -2,30 +2,24 @@ package com.twitter.querulous.unit
import java.sql.{SQLException, DriverManager, Connection}
import scala.collection.mutable
-import net.lag.configgy.{Config, Configgy}
import com.mysql.jdbc.exceptions.MySQLIntegrityConstraintViolationException
+import com.twitter.querulous.{StatsCollector, TestEvaluator}
import com.twitter.querulous.database.{ApachePoolingDatabaseFactory, MemoizingDatabaseFactory, Database}
import com.twitter.querulous.evaluator.{StandardQueryEvaluator, StandardQueryEvaluatorFactory, QueryEvaluator}
import com.twitter.querulous.query._
import com.twitter.querulous.test.FakeDatabase
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
-import org.specs.Specification
+import com.twitter.querulous.ConfiguredSpecification
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
import org.specs.mock.{ClassMocker, JMocker}
-class QueryEvaluatorSpec extends Specification with JMocker with ClassMocker {
- Configgy.configure("config/test.conf")
+class QueryEvaluatorSpec extends ConfiguredSpecification with JMocker with ClassMocker {
import TestEvaluator._
- val config = Configgy.config.configMap("db")
- val username = config("username")
- val password = config("password")
- val urlOptions = config.configMap("url_options").asMap.asInstanceOf[Map[String, String]]
-
"QueryEvaluator" should {
- val queryEvaluator = testEvaluatorFactory("localhost", "db_test", username, password, urlOptions)
- val rootQueryEvaluator = testEvaluatorFactory("localhost", null, username, password, urlOptions)
+ val queryEvaluator = testEvaluatorFactory(config)
+ val rootQueryEvaluator = testEvaluatorFactory(config.withoutDatabase)
val queryFactory = new SqlQueryFactory
doBefore {
@@ -36,27 +30,9 @@ class QueryEvaluatorSpec extends Specification with JMocker with ClassMocker {
queryEvaluator.execute("DROP TABLE IF EXISTS foo")
}
- "fromConfig" in {
- val stats = mock[StatsCollector]
- QueryFactory.fromConfig(Config.fromMap(Map.empty), None) must haveClass[SqlQueryFactory]
- QueryFactory.fromConfig(Config.fromMap(Map.empty), Some(stats)) must
- haveClass[StatsCollectingQueryFactory]
- QueryFactory.fromConfig(Config.fromMap(Map("query_timeout_default" -> "10")), None) must
- haveClass[TimingOutQueryFactory]
- QueryFactory.fromConfig(Config.fromMap(Map("retries" -> "10")), None) must
- haveClass[RetryingQueryFactory]
- QueryFactory.fromConfig(Config.fromMap(Map("debug" -> "true")), None) must
- haveClass[DebuggingQueryFactory]
-
- val config = new Config()
- config.setConfigMap("queries", new Config())
- config("query_timeout_default") = "10"
- QueryFactory.fromConfig(config, Some(stats)) must haveClass[TimingOutStatsCollectingQueryFactory]
- }
-
"connection pooling" in {
val connection = mock[Connection]
- val database = new FakeDatabase(connection, 1.millis)
+ val database = new FakeDatabase(connection)
"transactionally" >> {
val queryEvaluator = new StandardQueryEvaluator(database, queryFactory)
@@ -95,7 +71,8 @@ class QueryEvaluatorSpec extends Specification with JMocker with ClassMocker {
"fallback to a read slave" in {
// should always succeed if you have the right mysql driver.
- val queryEvaluator = testEvaluatorFactory(List("localhost:12349", "localhost"), "db_test", username, password)
+ val queryEvaluator = testEvaluatorFactory(
+ "localhost:12349" :: config.hostnames.toList, config.database, config.username, config.password)
queryEvaluator.selectOne("SELECT 1") { row => row.getInt(1) }.toList mustEqual List(1)
queryEvaluator.execute("CREATE TABLE foo (id INT)") must throwA[SQLException]
}
diff --git a/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala
index ae9e5bc..0515022 100644
--- a/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/SqlQuerySpec.scala
@@ -1,12 +1,11 @@
package com.twitter.querulous.unit
-import java.sql.{PreparedStatement, Connection, Types}
+import java.sql.{PreparedStatement, Connection, Types, SQLException}
import org.specs.Specification
import org.specs.mock.{ClassMocker, JMocker}
import com.twitter.querulous.query.NullValues._
import com.twitter.querulous.query.{NullValues, SqlQuery}
-
class SqlQuerySpec extends Specification with JMocker with ClassMocker {
"SqlQuery" should {
"typecast" in {
@@ -23,6 +22,50 @@ class SqlQuerySpec extends Specification with JMocker with ClassMocker {
}
new SqlQuery(connection, "SELECT * FROM foo WHERE id IN (?)", List(1, 2, 3)).select { _ => 1 }
}
+
+ "arrays of pairs" in {
+ val connection = mock[Connection]
+ val statement = mock[PreparedStatement]
+ expect {
+ one(connection).prepareStatement("SELECT * FROM foo WHERE (id, uid) IN ((?,?),(?,?))") willReturn statement
+ one(statement).setInt(1, 1) then
+ one(statement).setInt(2, 2) then
+ one(statement).setInt(3, 3) then
+ one(statement).setInt(4, 4) then
+ one(statement).executeQuery() then
+ one(statement).getResultSet
+ }
+ new SqlQuery(connection, "SELECT * FROM foo WHERE (id, uid) IN (?)", List((1, 2), (3, 4))).select { _ => 1 }
+ }
+
+ "arrays of tuple3s" in {
+ val connection = mock[Connection]
+ val statement = mock[PreparedStatement]
+ expect {
+ one(connection).prepareStatement("SELECT * FROM foo WHERE (id1, id2, id3) IN ((?,?,?))") willReturn statement
+ one(statement).setInt(1, 1) then
+ one(statement).setInt(2, 2) then
+ one(statement).setInt(3, 3) then
+ one(statement).executeQuery() then
+ one(statement).getResultSet
+ }
+ new SqlQuery(connection, "SELECT * FROM foo WHERE (id1, id2, id3) IN (?)", List((1, 2, 3))).select { _ => 1 }
+ }
+
+ "arrays of tuple4s" in {
+ val connection = mock[Connection]
+ val statement = mock[PreparedStatement]
+ expect {
+ one(connection).prepareStatement("SELECT * FROM foo WHERE (id1, id2, id3, id4) IN ((?,?,?,?))") willReturn statement
+ one(statement).setInt(1, 1) then
+ one(statement).setInt(2, 2) then
+ one(statement).setInt(3, 3) then
+ one(statement).setInt(4, 4) then
+ one(statement).executeQuery() then
+ one(statement).getResultSet
+ }
+ new SqlQuery(connection, "SELECT * FROM foo WHERE (id1, id2, id3, id4) IN (?)", List((1, 2, 3, 4))).select { _ => 1 }
+ }
}
"create a query string" in {
@@ -57,5 +100,34 @@ class SqlQuerySpec extends Specification with JMocker with ClassMocker {
new SqlQuery(connection, queryString, NullString, NullInt, NullDouble, NullBoolean, NullLong, NullValues(Types.VARBINARY))
}
+
+ "handle exceptions" in {
+ val queryString = "INSERT INTO TABLE (col1) VALUES (?)"
+ val connection = mock[Connection]
+ val statement = mock[PreparedStatement]
+ val unrecognizedType = connection
+ "throw illegal argument exception if type passed in is unrecognized" in {
+ expect {
+ one(connection).prepareStatement(queryString) willReturn statement
+ }
+ new SqlQuery(connection, queryString, unrecognizedType) must throwAn[IllegalArgumentException]
+ }
+ "throw chained-exception" in {
+ val expectedCauseException = new SQLException("")
+ expect {
+ one(connection).prepareStatement(queryString) willReturn statement then
+ one(statement).setString(1, "one") willThrow expectedCauseException
+ }
+ try {
+ new SqlQuery(connection, queryString, "one")
+ fail("should throw")
+ } catch {
+ case e: Exception => {
+ e.getCause must beEqualTo(expectedCauseException)
+ }
+ case _ => fail("unknown throwable")
+ }
+ }
+ }
}
}
diff --git a/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala b/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala
index ba6d1fe..a4ccab2 100644
--- a/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/StatsCollectingDatabaseSpec.scala
@@ -4,29 +4,47 @@ import scala.collection.mutable.Map
import java.sql.Connection
import org.specs.Specification
import org.specs.mock.{ClassMocker, JMocker}
-import com.twitter.querulous.database.StatsCollectingDatabase
+import com.twitter.querulous.database.{SqlDatabaseTimeoutException, StatsCollectingDatabase}
import com.twitter.querulous.test.{FakeStatsCollector, FakeDatabase}
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
class StatsCollectingDatabaseSpec extends Specification with JMocker with ClassMocker {
"StatsCollectingDatabase" should {
- Time.freeze()
val latency = 1.second
val connection = mock[Connection]
val stats = new FakeStatsCollector
- val pool = new StatsCollectingDatabase(new FakeDatabase(connection, latency), stats)
+ def pool(callback: String => Unit) = new StatsCollectingDatabase(new FakeDatabase(connection, callback), stats)
"collect stats" in {
"when closing" >> {
- pool.close(connection)
- stats.times("database-close-timing") mustEqual latency.inMillis
+ Time.withCurrentTimeFrozen { time =>
+ pool(s => time.advance(latency)).close(connection)
+ stats.times("db-close-timing") mustEqual latency.inMillis
+ }
}
"when opening" >> {
- pool.open()
- stats.times("database-open-timing") mustEqual latency.inMillis
+ Time.withCurrentTimeFrozen { time =>
+ pool(s => time.advance(latency)).open()
+ stats.times("db-open-timing") mustEqual latency.inMillis
+ }
+ }
+ }
+
+ "collect timeout stats" in {
+ val e = new SqlDatabaseTimeoutException("foo", 0.seconds)
+ "when closing" >> {
+ pool(s => throw e).close(connection) must throwA[SqlDatabaseTimeoutException]
+ stats.counts("db-close-timeout-count") mustEqual 1
+ }
+
+ "when opening" >> {
+ Time.withCurrentTimeFrozen { time =>
+ pool(s => throw e).open() must throwA[SqlDatabaseTimeoutException]
+ stats.counts("db-open-timeout-count") mustEqual 1
+ }
}
}
}
diff --git a/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala
index f062da8..c8d4201 100644
--- a/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/StatsCollectingQuerySpec.scala
@@ -3,31 +3,45 @@ package com.twitter.querulous.unit
import java.sql.ResultSet
import org.specs.Specification
import org.specs.mock.JMocker
-import com.twitter.querulous.query.StatsCollectingQuery
+import com.twitter.querulous.query.{QueryClass, SqlQueryTimeoutException, StatsCollectingQuery}
import com.twitter.querulous.test.{FakeQuery, FakeStatsCollector}
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
class StatsCollectingQuerySpec extends Specification with JMocker {
"StatsCollectingQuery" should {
- Time.freeze()
- val latency = 1.second
- val testQuery = new FakeQuery(List(mock[ResultSet])) {
- override def select[A](f: ResultSet => A) = {
- Time.advance(latency)
- super.select(f)
+ "collect stats" in {
+ Time.withCurrentTimeFrozen { time =>
+ val latency = 1.second
+ val stats = new FakeStatsCollector
+ val testQuery = new FakeQuery(List(mock[ResultSet])) {
+ override def select[A](f: ResultSet => A) = {
+ time.advance(latency)
+ super.select(f)
+ }
+ }
+ val statsCollectingQuery = new StatsCollectingQuery(testQuery, QueryClass.Select, stats)
+
+ statsCollectingQuery.select { _ => 1 } mustEqual List(1)
+
+ stats.counts("db-select-count") mustEqual 1
+ stats.times("db-timing") mustEqual latency.inMillis
}
}
- "collect stats" in {
- val stats = new FakeStatsCollector
- val statsCollectingQuery = new StatsCollectingQuery(testQuery, stats)
+ "collect timeout stats" in {
+ Time.withCurrentTimeFrozen { time =>
+ val stats = new FakeStatsCollector
+ val testQuery = new FakeQuery(List(mock[ResultSet]))
+ val statsCollectingQuery = new StatsCollectingQuery(testQuery, QueryClass.Select, stats)
+ val e = new SqlQueryTimeoutException(0.seconds)
- statsCollectingQuery.select { _ => 1 } mustEqual List(1)
+ statsCollectingQuery.select { _ => throw e } must throwA[SqlQueryTimeoutException]
- stats.counts("db-select-count") mustEqual 1
- stats.times("db-timing") mustEqual latency.inMillis
+ stats.counts("db-query-timeout-count") mustEqual 1
+ stats.counts("db-query-" + QueryClass.Select.name + "-timeout-count") mustEqual 1
+ }
}
}
}
diff --git a/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala b/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala
index 0785194..b1e8f3a 100644
--- a/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/TimingOutDatabaseSpec.scala
@@ -1,22 +1,23 @@
package com.twitter.querulous.unit
+import com.twitter.querulous._
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.sql.Connection
import com.twitter.querulous.TimeoutException
import com.twitter.querulous.database.{SqlDatabaseTimeoutException, TimingOutDatabase, Database}
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.util.Time
+import com.twitter.util.TimeConversions._
import org.specs.Specification
import org.specs.mock.{JMocker, ClassMocker}
class TimingOutDatabaseSpec extends Specification with JMocker with ClassMocker {
"TimingOutDatabaseSpec" should {
- Time.reset()
val latch = new CountDownLatch(1)
val timeout = 1.second
var shouldWait = false
val connection = mock[Connection]
+ val future = new FutureTimeout(1, 1)
val database = new Database {
def open() = {
if (shouldWait) latch.await(100.seconds.inMillis, TimeUnit.MILLISECONDS)
@@ -29,7 +30,7 @@ class TimingOutDatabaseSpec extends Specification with JMocker with ClassMocker
// one(connection).close()
}
- val timingOutDatabase = new TimingOutDatabase(database, List("dbhost"), "dbname", 1, 1, timeout, timeout, 1)
+ val timingOutDatabase = new TimingOutDatabase(database, List("dbhost"), "dbname", future, timeout, 1)
shouldWait = true
"timeout" in {
diff --git a/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala
index 1b93de8..6bc28cd 100644
--- a/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala
+++ b/src/test/scala/com/twitter/querulous/unit/TimingOutQuerySpec.scala
@@ -1,22 +1,22 @@
package com.twitter.querulous.unit
import java.sql.ResultSet
-import net.lag.configgy.Configgy
import org.specs.Specification
import org.specs.mock.{JMocker, ClassMocker}
+import com.twitter.querulous.TestEvaluator
import com.twitter.querulous.test.FakeQuery
import com.twitter.querulous.query.{TimingOutQuery, SqlQueryTimeoutException}
-import com.twitter.xrayspecs.Duration
-import com.twitter.xrayspecs.TimeConversions._
+import com.twitter.querulous.ConfiguredSpecification
+import com.twitter.util.Duration
+import com.twitter.util.TimeConversions._
import java.util.concurrent.{CountDownLatch, TimeUnit}
-class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
+object TimingOutQuerySpec extends ConfiguredSpecification with JMocker with ClassMocker {
"TimingOutQuery" should {
- val config = Configgy.config.configMap("db")
- val connection = TestEvaluator.testDatabaseFactory(List("localhost"), config("username"), config("password")).open()
+ val connection = TestEvaluator.testDatabaseFactory(
+ config.hostnames.toList, config.username, config.password).open()
val timeout = 1.second
- val cancelTimeout = 0.millis
val resultSet = mock[ResultSet]
"timeout" in {
@@ -29,7 +29,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
super.select(f)
}
}
- val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout)
+ val timingOutQuery = new TimingOutQuery(query, connection, timeout, true)
timingOutQuery.select { r => 1 } must throwA[SqlQueryTimeoutException]
latch.getCount mustEqual 0
@@ -40,7 +40,7 @@ class TimingOutQuerySpec extends Specification with JMocker with ClassMocker {
val query = new FakeQuery(List(resultSet)) {
override def cancel() = { latch.countDown() }
}
- val timingOutQuery = new TimingOutQuery(query, connection, timeout, cancelTimeout)
+ val timingOutQuery = new TimingOutQuery(query, connection, timeout, true)
timingOutQuery.select { r => 1 }
latch.getCount mustEqual 1
diff --git a/src/test/scala/com/twitter/querulous/unit/TimingOutStatsCollectingQuerySpec.scala b/src/test/scala/com/twitter/querulous/unit/TimingOutStatsCollectingQuerySpec.scala
deleted file mode 100644
index 8bef8d7..0000000
--- a/src/test/scala/com/twitter/querulous/unit/TimingOutStatsCollectingQuerySpec.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-package com.twitter.querulous.unit
-
-import java.sql.ResultSet
-import org.specs.Specification
-import org.specs.mock.{ClassMocker, JMocker}
-import com.twitter.querulous.test.{FakeQuery, FakeStatsCollector}
-import com.twitter.querulous.query.TimingOutStatsCollectingQuery
-import com.twitter.xrayspecs.Time
-import com.twitter.xrayspecs.TimeConversions._
-
-
-class TimingOutStatsCollectingQuerySpec extends Specification with JMocker {
- "TimingOutStatsCollectingQuery" should {
- Time.freeze()
- val latency = 1.second
- val testQuery = new FakeQuery(List(mock[ResultSet])) {
- override def select[A](f: ResultSet => A) = {
- Time.advance(latency)
- super.select(f)
- }
- override def execute() = {
- Time.advance(latency)
- super.execute()
- }
- }
-
- "collect stats" in {
- val stats = new FakeStatsCollector
-
- "selects" >> {
- val query = new TimingOutStatsCollectingQuery(testQuery, "selectTest", stats)
- query.select { _ => 1 } mustEqual List(1)
- stats.times("db-select-timing") mustEqual latency.inMillis
- stats.times("x-db-query-timing-selectTest") mustEqual latency.inMillis
- stats.counts("db-select-count") mustEqual 1
- }
-
- "executes" >> {
- val query = new TimingOutStatsCollectingQuery(testQuery, "executeTest", stats)
- query.execute
- stats.times("db-execute-timing") mustEqual latency.inMillis
- stats.times("x-db-query-timing-executeTest") mustEqual latency.inMillis
- stats.counts("db-execute-count") mustEqual 1
- }
-
- "globally" >> {
- val query = new TimingOutStatsCollectingQuery(testQuery, "default", stats)
- query.select { _ => 1 } mustEqual List(1)
- stats.times("db-timing") mustEqual latency.inMillis
- }
- }
- }
-}
-