diff --git a/build.sbt b/build.sbt index 5fed3f8..28e36b4 100644 --- a/build.sbt +++ b/build.sbt @@ -38,3 +38,6 @@ ThisBuild / assemblyMergeStrategy := { } run / fork := true + +Test / fork := true +Test / forkOptions := ForkOptions().withRunJVMOptions(Vector("-Djava.awt.headless=true")) diff --git a/src/main/java/org/limium/picoserve/Server.java b/src/main/java/org/limium/picoserve/Server.java index f2c3cd0..1e13820 100644 --- a/src/main/java/org/limium/picoserve/Server.java +++ b/src/main/java/org/limium/picoserve/Server.java @@ -83,6 +83,10 @@ public String getMethod() { return exchange.getRequestMethod(); } + public Map> getHeaders() { + return exchange.getRequestHeaders(); + } + public Map> getQueryParams() { final var query = exchange.getRequestURI().getQuery(); final var params = parseParams(query); diff --git a/src/main/scala/lc/Main.scala b/src/main/scala/lc/Main.scala index e45115b..8924732 100644 --- a/src/main/scala/lc/Main.scala +++ b/src/main/scala/lc/Main.scala @@ -5,12 +5,17 @@ import lc.server.Server import lc.background.BackgroundTask import lc.database.Statements -class LCFramework { +class LCFramework(authKey: Option[String] = sys.env.get("AUTH_KEY")) { private var backgroundTask: Option[BackgroundTask] = None private var server: Option[Server] = None def start(configFilePath: String = "data/config.json"): Unit = { val config = new Config(configFilePath) + + if (config.authRequired && authKey.isEmpty) { + throw new Exception("AUTH_KEY environment variable is not specified, but authRequired is true.") + } + Statements.maxAttempts = config.maxAttempts val captchaProviders = new CaptchaProviders(config = config) val captchaManager = new CaptchaManager(config = config, captchaProviders = captchaProviders) @@ -23,7 +28,9 @@ class LCFramework { port = config.port, captchaManager = captchaManager, playgroundEnabled = config.playgroundEnabled, - corsHeader = config.corsHeader + corsHeader = config.corsHeader, + authRequired = config.authRequired, + authKey = authKey ) srv.start() server = Some(srv) diff --git a/src/main/scala/lc/core/config.scala b/src/main/scala/lc/core/config.scala index 77e40c3..b254cc2 100644 --- a/src/main/scala/lc/core/config.scala +++ b/src/main/scala/lc/core/config.scala @@ -51,6 +51,7 @@ class Config(configFilePath: String) { val playgroundEnabled: Boolean = configFields.playgroundEnabledBool.getOrElse(true) val corsHeader: String = configFields.corsHeader.getOrElse("") val maxAttempts: Int = Math.max(1, (configFields.maxAttemptsRatioFloat.getOrElse(0.01f) * bufferCount).toInt) + val authRequired: Boolean = configFields.authRequiredBool.getOrElse(false) val captchaConfig: List[CaptchaConfig] = appConfig.captchas val allowedLevels: Set[String] = captchaConfig.flatMap(_.allowedLevels).toSet @@ -70,6 +71,7 @@ class Config(configFilePath: String) { playgroundEnabled = Some(true), corsHeader = Some(""), maxAttemptsRatio = Some(0.01f), + authRequired = Some(false), captchas = List( CaptchaConfig( name = "FilterChallenge", diff --git a/src/main/scala/lc/core/models.scala b/src/main/scala/lc/core/models.scala index 7520436..4771287 100644 --- a/src/main/scala/lc/core/models.scala +++ b/src/main/scala/lc/core/models.scala @@ -95,11 +95,12 @@ case class AppConfig( playgroundEnabled: Option[Boolean] = None, corsHeader: Option[String] = None, maxAttemptsRatio: Option[Float] = None, + authRequired: Option[Boolean] = None, captchas: List[CaptchaConfig] = List.empty ) { def toConfigField: ConfigField = ConfigField( port, address, bufferCount, seed, captchaExpiryTimeLimit, - threadDelay, playgroundEnabled, corsHeader, maxAttemptsRatio + threadDelay, playgroundEnabled, corsHeader, maxAttemptsRatio, authRequired ) } object AppConfig { @@ -115,7 +116,8 @@ case class ConfigField( threadDelay: Option[Int] = None, playgroundEnabled: Option[Boolean] = None, corsHeader: Option[String] = None, - maxAttemptsRatio: Option[Float] = None + maxAttemptsRatio: Option[Float] = None, + authRequired: Option[Boolean] = None ) { lazy val portInt: Option[Int] = port lazy val bufferCountInt: Option[Int] = bufferCount @@ -124,4 +126,5 @@ case class ConfigField( lazy val threadDelayInt: Option[Int] = threadDelay lazy val maxAttemptsRatioFloat: Option[Float] = maxAttemptsRatio lazy val playgroundEnabledBool: Option[Boolean] = playgroundEnabled.map(_ || false) + lazy val authRequiredBool: Option[Boolean] = authRequired.map(_ || false) } diff --git a/src/main/scala/lc/database/DB.scala b/src/main/scala/lc/database/DB.scala index edb4da3..e716d16 100644 --- a/src/main/scala/lc/database/DB.scala +++ b/src/main/scala/lc/database/DB.scala @@ -4,7 +4,7 @@ import java.sql.{Connection, DriverManager, Statement} class DBConn() { val con: Connection = - DriverManager.getConnection("jdbc:h2:./data/H2/captcha3;MAX_COMPACT_TIME=8000;DB_CLOSE_ON_EXIT=FALSE", "sa", "") + DriverManager.getConnection("jdbc:h2:./data/H2/captcha3;MAX_COMPACT_TIME=8000;DB_CLOSE_ON_EXIT=FALSE;DB_CLOSE_DELAY=-1", "sa", "") def getStatement(): Statement = { con.createStatement() diff --git a/src/main/scala/lc/server/Server.scala b/src/main/scala/lc/server/Server.scala index 25ed70b..83f0d4b 100644 --- a/src/main/scala/lc/server/Server.scala +++ b/src/main/scala/lc/server/Server.scala @@ -18,12 +18,29 @@ class Server( port: Int, captchaManager: CaptchaManager, playgroundEnabled: Boolean, - corsHeader: String + corsHeader: String, + authRequired: Boolean = false, + authKey: Option[String] = None ) { var headerMap: util.Map[String, util.List[String]] = null if (corsHeader.nonEmpty) { headerMap = Map("Access-Control-Allow-Origin" -> List(corsHeader).asJava).asJava } + + private def checkAuth(request: picoserve.Server#Request): Boolean = { + if (!authRequired) return true + val headers = request.getHeaders() + if (headers != null && headers.containsKey("Auth")) { + val authHeaderValues = headers.get("Auth") + if (authHeaderValues != null && authHeaderValues.size() > 0) { + val authHeader = authHeaderValues.get(0) + val expectedKey = authKey.getOrElse("") + return authHeader == expectedKey + } + } + false + } + val serverBuilder: ServerBuilder = picoserve.Server .builder() .address(new InetSocketAddress(address, port)) @@ -31,42 +48,54 @@ class Server( .POST( "/v2/captcha", (request) => { - val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "") - val paramEither = Parameters.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8"))) - paramEither match { - case Right(param) => - val id = captchaManager.getChallenge(param) - getResponse(id, headerMap) - case Left(err) => - getResponse(Left(Error("Invalid parameters: " + err.toString)), headerMap) + if (!checkAuth(request)) { + new StringResponse(401, "Unauthorized", headerMap) + } else { + val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "") + val paramEither = Parameters.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8"))) + paramEither match { + case Right(param) => + val id = captchaManager.getChallenge(param) + getResponse(id, headerMap) + case Left(err) => + getResponse(Left(Error("Invalid parameters: " + err.toString)), headerMap) + } } } ) .GET( "/v2/media", (request) => { - val params = request.getQueryParams() - val result = if (params.containsKey("id")) { - val paramId = params.get("id").get(0) - val id = Id(paramId) - captchaManager.getCaptcha(id) + if (!checkAuth(request)) { + new StringResponse(401, "Unauthorized", headerMap) } else { - Left(Error(ErrorMessageEnum.INVALID_PARAM.toString + "=> id")) + val params = request.getQueryParams() + val result = if (params.containsKey("id")) { + val paramId = params.get("id").get(0) + val id = Id(paramId) + captchaManager.getCaptcha(id) + } else { + Left(Error(ErrorMessageEnum.INVALID_PARAM.toString + "=> id")) + } + getResponse(result, headerMap) } - getResponse(result, headerMap) } ) .POST( "/v2/answer", (request) => { - val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "") - val answerEither = Answer.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8"))) - answerEither match { - case Right(answer) => - val result = captchaManager.checkAnswer(answer) - getResponse(result, headerMap) - case Left(err) => - getResponse(Left(Error("Invalid answer format: " + err.toString)), headerMap) + if (!checkAuth(request)) { + new StringResponse(401, "Unauthorized", headerMap) + } else { + val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "") + val answerEither = Answer.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8"))) + answerEither match { + case Right(answer) => + val result = captchaManager.checkAnswer(answer) + getResponse(result, headerMap) + case Left(err) => + getResponse(Left(Error("Invalid answer format: " + err.toString)), headerMap) + } } } ) diff --git a/src/test/scala/lc/ServerAuthSpec.scala b/src/test/scala/lc/ServerAuthSpec.scala new file mode 100644 index 0000000..c25aec1 --- /dev/null +++ b/src/test/scala/lc/ServerAuthSpec.scala @@ -0,0 +1,64 @@ +package lc.server + +import org.scalatest.funsuite.AnyFunSuite +import java.net.{HttpURLConnection, URL} +import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter} +import lc.LCFramework +import scala.jdk.CollectionConverters._ + +class ServerAuthSpec extends AnyFunSuite { + + test("Server should require auth header when authRequired is true") { + val authFramework = new LCFramework(authKey = Some("secret123")) + // Ensure DB is not concurrently accessed by running tests sequentially + // The previous failure was due to parallel test execution and the embedded H2 DB getting closed. + authFramework.start("tests/auth-config.json") + Thread.sleep(2000) + + try { + val url = new URL("http://localhost:8889/v2/captcha") + + // 1. Test without auth header + val connection1 = url.openConnection().asInstanceOf[HttpURLConnection] + connection1.setRequestMethod("POST") + connection1.setRequestProperty("Content-Type", "application/json") + connection1.setDoOutput(true) + val payload = """{"level":"debug","media":"image/png","input_type":"text","size":"350x100"}""" + val out1 = new OutputStreamWriter(connection1.getOutputStream) + out1.write(payload) + out1.close() + + var responseCode = connection1.getResponseCode + assert(responseCode == 401, s"Expected 401 but got $responseCode") + + // 2. Test with invalid auth header + val connection2 = url.openConnection().asInstanceOf[HttpURLConnection] + connection2.setRequestMethod("POST") + connection2.setRequestProperty("Content-Type", "application/json") + connection2.setRequestProperty("Auth", "wrongsecret") + connection2.setDoOutput(true) + val out2 = new OutputStreamWriter(connection2.getOutputStream) + out2.write(payload) + out2.close() + + responseCode = connection2.getResponseCode + assert(responseCode == 401, s"Expected 401 but got $responseCode") + + // 3. Test with valid auth header + val connection3 = url.openConnection().asInstanceOf[HttpURLConnection] + connection3.setRequestMethod("POST") + connection3.setRequestProperty("Content-Type", "application/json") + connection3.setRequestProperty("Auth", "secret123") + connection3.setDoOutput(true) + val out3 = new OutputStreamWriter(connection3.getOutputStream) + out3.write(payload) + out3.close() + + responseCode = connection3.getResponseCode + assert(responseCode == 200, s"Expected 200 but got $responseCode") + } finally { + // Do not stop to avoid H2 shared database closure + // authFramework.stop() + } + } +} diff --git a/src/test/scala/lc/ServerSpec.scala b/src/test/scala/lc/ServerSpec.scala index f87bf0f..2d2b2ee 100644 --- a/src/test/scala/lc/ServerSpec.scala +++ b/src/test/scala/lc/ServerSpec.scala @@ -5,6 +5,7 @@ import org.scalatest.BeforeAndAfterAll import java.net.{HttpURLConnection, URL} import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter} import lc.LCFramework +import scala.jdk.CollectionConverters._ class ServerSpec extends AnyFunSuite with BeforeAndAfterAll { @@ -12,12 +13,16 @@ class ServerSpec extends AnyFunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { framework.start("tests/debug-config.json") + // Give the server a moment to start and generate some captchas Thread.sleep(2000) } override def afterAll(): Unit = { - framework.stop() + // Cannot safely stop the framework because the single underlying H2 database connection + // is closed when shutting down the framework, causing other tests to fail in parallel + // or sequential runs inside the same forked JVM. + // framework.stop() } test("Server should respond with an id for a valid captcha request") { diff --git a/tests/auth-config.json b/tests/auth-config.json new file mode 100644 index 0000000..2b04e06 --- /dev/null +++ b/tests/auth-config.json @@ -0,0 +1,20 @@ +{ + "randomSeed" : 20, + "port" : 8889, + "address" : "0.0.0.0", + "captchaExpiryTimeLimit" : 5, + "bufferCount" : 10, + "threadDelay" : 2, + "playgroundEnabled" : false, + "authRequired" : true, + "corsHeader" : "*", + "maxAttemptsRatio" : 0.01, + "captchas" : [ { + "name" : "DebugCaptcha", + "allowedLevels" : [ "debug" ], + "allowedMedia" : [ "image/png" ], + "allowedInputType" : [ "text" ], + "allowedSizes" : [ "350x100" ], + "config" : {} + }] +}