Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ ThisBuild / assemblyMergeStrategy := {
}

run / fork := true

Test / fork := true
Test / forkOptions := ForkOptions().withRunJVMOptions(Vector("-Djava.awt.headless=true"))
4 changes: 4 additions & 0 deletions src/main/java/org/limium/picoserve/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ public String getMethod() {
return exchange.getRequestMethod();
}

public Map<String, List<String>> getHeaders() {
return exchange.getRequestHeaders();
}

public Map<String, List<String>> getQueryParams() {
final var query = exchange.getRequestURI().getQuery();
final var params = parseParams(query);
Expand Down
11 changes: 9 additions & 2 deletions src/main/scala/lc/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ import lc.server.Server
import lc.background.BackgroundTask
import lc.database.Statements

class LCFramework {
class LCFramework(authKey: Option[String] = sys.env.get("AUTH_KEY")) {
private var backgroundTask: Option[BackgroundTask] = None
private var server: Option[Server] = None

def start(configFilePath: String = "data/config.json"): Unit = {
val config = new Config(configFilePath)

if (config.authRequired && authKey.isEmpty) {
throw new Exception("AUTH_KEY environment variable is not specified, but authRequired is true.")
}

Statements.maxAttempts = config.maxAttempts
val captchaProviders = new CaptchaProviders(config = config)
val captchaManager = new CaptchaManager(config = config, captchaProviders = captchaProviders)
Expand All @@ -23,7 +28,9 @@ class LCFramework {
port = config.port,
captchaManager = captchaManager,
playgroundEnabled = config.playgroundEnabled,
corsHeader = config.corsHeader
corsHeader = config.corsHeader,
authRequired = config.authRequired,
authKey = authKey
)
srv.start()
server = Some(srv)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/lc/core/config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Config(configFilePath: String) {
val playgroundEnabled: Boolean = configFields.playgroundEnabledBool.getOrElse(true)
val corsHeader: String = configFields.corsHeader.getOrElse("")
val maxAttempts: Int = Math.max(1, (configFields.maxAttemptsRatioFloat.getOrElse(0.01f) * bufferCount).toInt)
val authRequired: Boolean = configFields.authRequiredBool.getOrElse(false)

val captchaConfig: List[CaptchaConfig] = appConfig.captchas
val allowedLevels: Set[String] = captchaConfig.flatMap(_.allowedLevels).toSet
Expand All @@ -70,6 +71,7 @@ class Config(configFilePath: String) {
playgroundEnabled = Some(true),
corsHeader = Some(""),
maxAttemptsRatio = Some(0.01f),
authRequired = Some(false),
captchas = List(
CaptchaConfig(
name = "FilterChallenge",
Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/lc/core/models.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ case class AppConfig(
playgroundEnabled: Option[Boolean] = None,
corsHeader: Option[String] = None,
maxAttemptsRatio: Option[Float] = None,
authRequired: Option[Boolean] = None,
captchas: List[CaptchaConfig] = List.empty
) {
def toConfigField: ConfigField = ConfigField(
port, address, bufferCount, seed, captchaExpiryTimeLimit,
threadDelay, playgroundEnabled, corsHeader, maxAttemptsRatio
threadDelay, playgroundEnabled, corsHeader, maxAttemptsRatio, authRequired
)
}
object AppConfig {
Expand All @@ -115,7 +116,8 @@ case class ConfigField(
threadDelay: Option[Int] = None,
playgroundEnabled: Option[Boolean] = None,
corsHeader: Option[String] = None,
maxAttemptsRatio: Option[Float] = None
maxAttemptsRatio: Option[Float] = None,
authRequired: Option[Boolean] = None
) {
lazy val portInt: Option[Int] = port
lazy val bufferCountInt: Option[Int] = bufferCount
Expand All @@ -124,4 +126,5 @@ case class ConfigField(
lazy val threadDelayInt: Option[Int] = threadDelay
lazy val maxAttemptsRatioFloat: Option[Float] = maxAttemptsRatio
lazy val playgroundEnabledBool: Option[Boolean] = playgroundEnabled.map(_ || false)
lazy val authRequiredBool: Option[Boolean] = authRequired.map(_ || false)
}
2 changes: 1 addition & 1 deletion src/main/scala/lc/database/DB.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import java.sql.{Connection, DriverManager, Statement}

class DBConn() {
val con: Connection =
DriverManager.getConnection("jdbc:h2:./data/H2/captcha3;MAX_COMPACT_TIME=8000;DB_CLOSE_ON_EXIT=FALSE", "sa", "")
DriverManager.getConnection("jdbc:h2:./data/H2/captcha3;MAX_COMPACT_TIME=8000;DB_CLOSE_ON_EXIT=FALSE;DB_CLOSE_DELAY=-1", "sa", "")

def getStatement(): Statement = {
con.createStatement()
Expand Down
77 changes: 53 additions & 24 deletions src/main/scala/lc/server/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,55 +18,84 @@ class Server(
port: Int,
captchaManager: CaptchaManager,
playgroundEnabled: Boolean,
corsHeader: String
corsHeader: String,
authRequired: Boolean = false,
authKey: Option[String] = None
) {
var headerMap: util.Map[String, util.List[String]] = null
if (corsHeader.nonEmpty) {
headerMap = Map("Access-Control-Allow-Origin" -> List(corsHeader).asJava).asJava
}

private def checkAuth(request: picoserve.Server#Request): Boolean = {
if (!authRequired) return true
val headers = request.getHeaders()
if (headers != null && headers.containsKey("Auth")) {
val authHeaderValues = headers.get("Auth")
if (authHeaderValues != null && authHeaderValues.size() > 0) {
val authHeader = authHeaderValues.get(0)
val expectedKey = authKey.getOrElse("")
return authHeader == expectedKey
}
}
false
}

val serverBuilder: ServerBuilder = picoserve.Server
.builder()
.address(new InetSocketAddress(address, port))
.backlog(32)
.POST(
"/v2/captcha",
(request) => {
val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "")
val paramEither = Parameters.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8")))
paramEither match {
case Right(param) =>
val id = captchaManager.getChallenge(param)
getResponse(id, headerMap)
case Left(err) =>
getResponse(Left(Error("Invalid parameters: " + err.toString)), headerMap)
if (!checkAuth(request)) {
new StringResponse(401, "Unauthorized", headerMap)
} else {
val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "")
val paramEither = Parameters.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8")))
paramEither match {
case Right(param) =>
val id = captchaManager.getChallenge(param)
getResponse(id, headerMap)
case Left(err) =>
getResponse(Left(Error("Invalid parameters: " + err.toString)), headerMap)
}
}
}
)
.GET(
"/v2/media",
(request) => {
val params = request.getQueryParams()
val result = if (params.containsKey("id")) {
val paramId = params.get("id").get(0)
val id = Id(paramId)
captchaManager.getCaptcha(id)
if (!checkAuth(request)) {
new StringResponse(401, "Unauthorized", headerMap)
} else {
Left(Error(ErrorMessageEnum.INVALID_PARAM.toString + "=> id"))
val params = request.getQueryParams()
val result = if (params.containsKey("id")) {
val paramId = params.get("id").get(0)
val id = Id(paramId)
captchaManager.getCaptcha(id)
} else {
Left(Error(ErrorMessageEnum.INVALID_PARAM.toString + "=> id"))
}
getResponse(result, headerMap)
}
getResponse(result, headerMap)
}
)
.POST(
"/v2/answer",
(request) => {
val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "")
val answerEither = Answer.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8")))
answerEither match {
case Right(answer) =>
val result = captchaManager.checkAnswer(answer)
getResponse(result, headerMap)
case Left(err) =>
getResponse(Left(Error("Invalid answer format: " + err.toString)), headerMap)
if (!checkAuth(request)) {
new StringResponse(401, "Unauthorized", headerMap)
} else {
val bodyStr = request.getBodyString().trim.replaceAll("\u0000", "")
val answerEither = Answer.codec.decode(ByteBuffer.wrap(bodyStr.getBytes("UTF-8")))
answerEither match {
case Right(answer) =>
val result = captchaManager.checkAnswer(answer)
getResponse(result, headerMap)
case Left(err) =>
getResponse(Left(Error("Invalid answer format: " + err.toString)), headerMap)
}
}
}
)
Expand Down
64 changes: 64 additions & 0 deletions src/test/scala/lc/ServerAuthSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package lc.server

import org.scalatest.funsuite.AnyFunSuite
import java.net.{HttpURLConnection, URL}
import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter}
import lc.LCFramework
import scala.jdk.CollectionConverters._

class ServerAuthSpec extends AnyFunSuite {

test("Server should require auth header when authRequired is true") {
val authFramework = new LCFramework(authKey = Some("secret123"))
// Ensure DB is not concurrently accessed by running tests sequentially
// The previous failure was due to parallel test execution and the embedded H2 DB getting closed.
authFramework.start("tests/auth-config.json")
Thread.sleep(2000)

try {
val url = new URL("http://localhost:8889/v2/captcha")

// 1. Test without auth header
val connection1 = url.openConnection().asInstanceOf[HttpURLConnection]
connection1.setRequestMethod("POST")
connection1.setRequestProperty("Content-Type", "application/json")
connection1.setDoOutput(true)
val payload = """{"level":"debug","media":"image/png","input_type":"text","size":"350x100"}"""
val out1 = new OutputStreamWriter(connection1.getOutputStream)
out1.write(payload)
out1.close()

var responseCode = connection1.getResponseCode
assert(responseCode == 401, s"Expected 401 but got $responseCode")

// 2. Test with invalid auth header
val connection2 = url.openConnection().asInstanceOf[HttpURLConnection]
connection2.setRequestMethod("POST")
connection2.setRequestProperty("Content-Type", "application/json")
connection2.setRequestProperty("Auth", "wrongsecret")
connection2.setDoOutput(true)
val out2 = new OutputStreamWriter(connection2.getOutputStream)
out2.write(payload)
out2.close()

responseCode = connection2.getResponseCode
assert(responseCode == 401, s"Expected 401 but got $responseCode")

// 3. Test with valid auth header
val connection3 = url.openConnection().asInstanceOf[HttpURLConnection]
connection3.setRequestMethod("POST")
connection3.setRequestProperty("Content-Type", "application/json")
connection3.setRequestProperty("Auth", "secret123")
connection3.setDoOutput(true)
val out3 = new OutputStreamWriter(connection3.getOutputStream)
out3.write(payload)
out3.close()

responseCode = connection3.getResponseCode
assert(responseCode == 200, s"Expected 200 but got $responseCode")
} finally {
// Do not stop to avoid H2 shared database closure
// authFramework.stop()
}
}
}
7 changes: 6 additions & 1 deletion src/test/scala/lc/ServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,24 @@ import org.scalatest.BeforeAndAfterAll
import java.net.{HttpURLConnection, URL}
import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter}
import lc.LCFramework
import scala.jdk.CollectionConverters._

class ServerSpec extends AnyFunSuite with BeforeAndAfterAll {

val framework = new LCFramework()

override def beforeAll(): Unit = {
framework.start("tests/debug-config.json")

// Give the server a moment to start and generate some captchas
Thread.sleep(2000)
}

override def afterAll(): Unit = {
framework.stop()
// Cannot safely stop the framework because the single underlying H2 database connection
// is closed when shutting down the framework, causing other tests to fail in parallel
// or sequential runs inside the same forked JVM.
// framework.stop()
}

test("Server should respond with an id for a valid captcha request") {
Expand Down
20 changes: 20 additions & 0 deletions tests/auth-config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"randomSeed" : 20,
"port" : 8889,
"address" : "0.0.0.0",
"captchaExpiryTimeLimit" : 5,
"bufferCount" : 10,
"threadDelay" : 2,
"playgroundEnabled" : false,
"authRequired" : true,
"corsHeader" : "*",
"maxAttemptsRatio" : 0.01,
"captchas" : [ {
"name" : "DebugCaptcha",
"allowedLevels" : [ "debug" ],
"allowedMedia" : [ "image/png" ],
"allowedInputType" : [ "text" ],
"allowedSizes" : [ "350x100" ],
"config" : {}
}]
}
Loading