diff --git a/.gitignore b/.gitignore index 0222c5bf..620f0f48 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # various directories target/ +.bloop/ +.bsp/ .idea/ +.metals/ project/target project/project/target/ stuff @@ -14,6 +17,7 @@ include-*/ # hidden files *.~ +.DS_Store #tools *.bat diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 00000000..ba14fb70 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1 @@ +version = "2.6.4" diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..33d99e8d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,18 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "scala", + "name": "Debug", + "request": "launch", + "mainClass": "millfork.Main", + // optional jvm properties to use + "jvmOptions": [], + "args": [] + }, + + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..e72490fb --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.watcherExclude": { + "**/target": true + } +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 00000000..c523fc58 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,17 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=733558 + // for the documentation about the tasks.json format + "version": "2.0.0", + "tasks": [ + { + "label": "Compile Millfork", + "type": "shell", + "command": "sbt -DskipTests=true compile && sbt -DskipTests=true assembly", + "problemMatcher": [], + "group": { + "kind": "build", + "isDefault": true + } + } + ] +} \ No newline at end of file diff --git a/build.sbt b/build.sbt index 29a866f9..99da51cd 100644 --- a/build.sbt +++ b/build.sbt @@ -10,6 +10,10 @@ libraryDependencies += "com.lihaoyi" %% "fastparse" % "1.0.0" libraryDependencies += "org.apache.commons" % "commons-configuration2" % "2.2" +libraryDependencies += "org.eclipse.lsp4j" % "org.eclipse.lsp4j" % "0.9.0" + +libraryDependencies += "net.liftweb" %% "lift-json" % "3.4.2" + libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.8" % "test" val testDependencies = Seq( @@ -33,7 +37,7 @@ val testDependencies = Seq( val includesTests = System.getProperty("skipTests") == null -libraryDependencies ++=( +libraryDependencies ++= ( if (includesTests) { println("Including test dependencies") testDependencies @@ -43,26 +47,27 @@ libraryDependencies ++=( ) (if (!includesTests) { - // Disable assembling tests - sbt.internals.DslEntry.fromSettingsDef(test in assembly := {}) -} else { - sbt.internals.DslEntry.fromSettingsDef(Seq[sbt.Def.Setting[_]]()) -}) + // Disable assembling tests + sbt.internal.DslEntry.fromSettingsDef(test in assembly := {}) + } else { + sbt.internal.DslEntry.fromSettingsDef(Seq[sbt.Def.Setting[_]]()) + }) mainClass in Compile := Some("millfork.Main") assemblyJarName := "millfork.jar" -lazy val root = (project in file(".")). - enablePlugins(BuildInfoPlugin). - settings( +lazy val root = (project in file(".")) + .enablePlugins(BuildInfoPlugin) + .settings( buildInfoKeys := Seq[BuildInfoKey](name, version, scalaVersion, sbtVersion), buildInfoPackage := "millfork.buildinfo" ) import sbtassembly.AssemblyKeys -val releaseDist = TaskKey[File]("release-dist", "Creates a distributable zip file.") +val releaseDist = + TaskKey[File]("release-dist", "Creates a distributable zip file.") releaseDist := { val jar = AssemblyKeys.assembly.value @@ -79,7 +84,10 @@ releaseDist := { IO.createDirectory(distDir) IO.copyFile(jar, distDir / jar.name) IO.copyFile(base / "LICENSE", distDir / "LICENSE") - IO.copyFile(base / "src/3rd-party-licenses.txt", distDir / "3rd-party-licenses.txt") + IO.copyFile( + base / "src/3rd-party-licenses.txt", + distDir / "3rd-party-licenses.txt" + ) IO.copyFile(base / "CHANGELOG.md", distDir / "CHANGELOG.md") IO.copyFile(base / "README.md", distDir / "README.md") IO.copyFile(base / "COMPILING.md", distDir / "COMPILING.md") @@ -89,8 +97,14 @@ releaseDist := { } copyDir("include") copyDir("docs") - def entries(f: File): List[File] = f :: (if (f.isDirectory) IO.listFiles(f).toList.flatMap(entries) else Nil) - IO.zip(entries(distDir).map(d => (d, d.getAbsolutePath.substring(distDir.getParent.length + 1))), zipFile) + def entries(f: File): List[File] = + f :: (if (f.isDirectory) IO.listFiles(f).toList.flatMap(entries) else Nil) + IO.zip( + entries(distDir).map(d => + (d, d.getAbsolutePath.substring(distDir.getParent.length + 1)) + ), + zipFile + ) IO.delete(distDir) zipFile } diff --git a/project/build.properties b/project/build.properties index 5a1071cb..2ec6cd91 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version = 0.13.18 +sbt.version = 1.4.0 diff --git a/project/metals.sbt b/project/metals.sbt new file mode 100644 index 00000000..aa04ddb5 --- /dev/null +++ b/project/metals.sbt @@ -0,0 +1,4 @@ +// DO NOT EDIT! This file is auto-generated. +// This file enables sbt-bloop to create bloop config files. + +addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.4-13-408f4d80") diff --git a/project/project/metals.sbt b/project/project/metals.sbt new file mode 100644 index 00000000..aa04ddb5 --- /dev/null +++ b/project/project/metals.sbt @@ -0,0 +1,4 @@ +// DO NOT EDIT! This file is auto-generated. +// This file enables sbt-bloop to create bloop config files. + +addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.4-13-408f4d80") diff --git a/project/project/project/metals.sbt b/project/project/project/metals.sbt new file mode 100644 index 00000000..aa04ddb5 --- /dev/null +++ b/project/project/project/metals.sbt @@ -0,0 +1,4 @@ +// DO NOT EDIT! This file is auto-generated. +// This file enables sbt-bloop to create bloop config files. + +addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.4.4-13-408f4d80") diff --git a/src/main/scala/millfork/Context.scala b/src/main/scala/millfork/Context.scala index 18a637f6..ac672d5e 100644 --- a/src/main/scala/millfork/Context.scala +++ b/src/main/scala/millfork/Context.scala @@ -9,6 +9,7 @@ import millfork.error.Logger case class Context(errorReporting: Logger, inputFileNames: List[String], outputFileName: Option[String] = None, + configFilePath: Option[String] = None, runFileName: Option[String] = None, runParams: Seq[String] = Vector(), optimizationLevel: Option[Int] = None, @@ -21,7 +22,8 @@ case class Context(errorReporting: Logger, extraIncludePath: Seq[String] = IndexedSeq(), flags: Map[CompilationFlag.Value, Boolean] = Map(), features: Map[String, Long] = Map(), - verbosity: Option[Int] = None) { + verbosity: Option[Int] = None, + languageServer: Boolean = false) { def changeFlag(f: CompilationFlag.Value, b: Boolean): Context = { if (flags.contains(f)) { if (flags(f) != b) { diff --git a/src/main/scala/millfork/Main.scala b/src/main/scala/millfork/Main.scala index bf2c0017..c946b98b 100644 --- a/src/main/scala/millfork/Main.scala +++ b/src/main/scala/millfork/Main.scala @@ -19,11 +19,15 @@ import millfork.node.StandardCallGraph import millfork.output._ import millfork.parser.{MSourceLoadingQueue, MosSourceLoadingQueue, TextCodecRepository, ZSourceLoadingQueue} +import millfork.language.{MfLanguageServer,MfLanguageClient,LanguageServerLogger} +import org.eclipse.lsp4j.services.LanguageServer +import org.eclipse.lsp4j.jsonrpc.Launcher +import java.util.concurrent.Executors +import java.io.PrintWriter +import millfork.cli.JsonConfigParser object Main { - - def main(args: Array[String]): Unit = { val errorReporting = new ConsoleLogger implicit val __implicitLogger: Logger = errorReporting @@ -34,6 +38,7 @@ object Main { val startTime = System.nanoTime() val (status, c0) = parser(errorReporting).parse(Context(errorReporting, Nil), args.toList) + val c1 = JsonConfigParser.parseConfig(c0, errorReporting) status match { case CliStatus.Quit => return case CliStatus.Failed => @@ -41,8 +46,8 @@ object Main { case CliStatus.Ok => () } errorReporting.assertNoErrors("Invalid command line") - errorReporting.verbosity = c0.verbosity.getOrElse(0) - if (c0.inputFileNames.isEmpty) { + errorReporting.verbosity = c1.verbosity.getOrElse(0) + if (c1.inputFileNames.isEmpty && !c1.languageServer) { errorReporting.fatalQuit("No input files") } @@ -51,14 +56,14 @@ object Main { errorReporting.trace("This program comes with ABSOLUTELY NO WARRANTY.") errorReporting.trace("This is free software, and you are welcome to redistribute it under certain conditions") errorReporting.trace("You should have received a copy of the GNU General Public License along with this program. If not, see https://www.gnu.org/licenses/") - val c = fixMissingIncludePath(c0).filloutFlags() + val c = fixMissingIncludePath(c1).filloutFlags() if (c.includePath.isEmpty) { errorReporting.warn("Failed to detect the default include directory, consider using the -I option") } val textCodecRepository = new TextCodecRepository("." :: c.includePath) val platform = Platform.lookupPlatformFile("." :: c.includePath, c.platform.getOrElse { - errorReporting.info("No platform selected, defaulting to `c64`") + if (!c1.languageServer) errorReporting.info("No platform selected, defaulting to `c64`") "c64" }, textCodecRepository) val options = CompilationOptions(platform, c.flags, c.outputFileName, c.zpRegisterSize.getOrElse(platform.zpRegisterSize), c.features, textCodecRepository, JobContext(errorReporting, new LabelGenerator)) @@ -67,6 +72,25 @@ object Main { case (f, b) => errorReporting.debug(f" $f%-30s : $b%s") } + if (c1.languageServer) { + // We cannot log anything to stdout when starting the language server (otherwise it's a protocol violation) + errorReporting.setOutput(true) + val server = new MfLanguageServer(c, options) + + val exec = Executors.newCachedThreadPool() + + val launcher = new Launcher.Builder[MfLanguageClient]() + .setExecutorService(exec) + .setInput(System.in) + .setOutput(System.out) + .setRemoteInterface(classOf[MfLanguageClient]) + .setLocalService(server) + .create() + val clientProxy = launcher.getRemoteProxy + server.client = Some(clientProxy) + launcher.startListening().get() + } + val output = c.outputFileName match { case Some(ofn) => ofn case None => c.inputFileNames match { @@ -252,7 +276,7 @@ object Main { val unoptimized = new MosSourceLoadingQueue( initialFilenames = c.inputFileNames, includePath = c.includePath, - options = options).run() + options = options).run().compilationOrderProgram val program = if (optLevel > 0) { OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options)) @@ -306,7 +330,7 @@ object Main { val unoptimized = new ZSourceLoadingQueue( initialFilenames = c.inputFileNames, includePath = c.includePath, - options = options).run() + options = options).run().compilationOrderProgram val program = if (optLevel > 0) { OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options)) @@ -346,7 +370,7 @@ object Main { val unoptimized = new MSourceLoadingQueue( initialFilenames = c.inputFileNames, includePath = c.includePath, - options = options).run() + options = options).run().compilationOrderProgram val program = if (optLevel > 0) { OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options)) @@ -376,7 +400,7 @@ object Main { val unoptimized = new ZSourceLoadingQueue( initialFilenames = c.inputFileNames, includePath = c.includePath, - options = options).run() + options = options).run().compilationOrderProgram val program = if (optLevel > 0) { OptimizationPresets.NodeOpt.foldLeft(unoptimized)((p, opt) => p.applyNodeOptimization(opt, options)) @@ -429,6 +453,15 @@ object Main { c.copy(outputLabels = true, outputLabelsFormatOverride = Some(f)) }.description("Generate also the label file in the given format. Available options: vice, nesasm, sym.") + flag("-lsp").action { c => + c.copy(languageServer = true) + }.description("Start the Millfork language server. Does not start compilation.") + + parameter("-c", "--config").placeholder("").action { (p, c) => + assertNone(c.outputFileName, "Config file already defined") + c.copy(configFilePath = Some(p)) + }.description("The Millfork config file. Suppliments the provided CLI options.") + boolean("-fbreakpoints", "-fno-breakpoints").action((c,v) => c.changeFlag(CompilationFlag.EnableBreakpoints, v) ).description("Include breakpoints in the label file. Requires either -g or -G.") diff --git a/src/main/scala/millfork/cli/JsonConfigParser.scala b/src/main/scala/millfork/cli/JsonConfigParser.scala new file mode 100644 index 00000000..da89206e --- /dev/null +++ b/src/main/scala/millfork/cli/JsonConfigParser.scala @@ -0,0 +1,65 @@ +package millfork.cli + +import net.liftweb.json._ +import java.nio.file.Files +import java.nio.file.Paths +import java.nio.charset.StandardCharsets +import scala.collection.mutable +import scala.collection.convert.ImplicitConversionsToScala._ +import java.io.InputStreamReader +import millfork.Context +import millfork.error.ConsoleLogger + +case class JsonConfig( + include: Option[List[String]], + platform: Option[String], + inputFiles: Option[List[String]] +) + +object JsonConfigParser { + implicit val formats = DefaultFormats + + def parseConfig(context: Context, logger: ConsoleLogger): Context = { + var newContext = context + + var defaultConfig = false + val filePath = context.configFilePath.getOrElse({ + defaultConfig = true + ".millforkrc.json" + }) + + val path = Paths.get(filePath) + + try { + val jsonString = + Files + .readAllLines(path, StandardCharsets.UTF_8) + .toIndexedSeq + .mkString("") + + val result = parse(jsonString).extract[JsonConfig] + + if (context.inputFileNames.length < 1 && result.inputFiles.isDefined) { + newContext = newContext.copy(inputFileNames = result.inputFiles.get) + } + + if (context.includePath.length < 1 && result.include.isDefined) { + newContext = + newContext.copy(extraIncludePath = result.include.get.toSeq) + } + + if (context.platform.isEmpty && result.platform.isDefined) { + newContext = newContext.copy(platform = Some(result.platform.get)) + } + } catch { + case default: Throwable => { + if (!defaultConfig) { + // Only throw error if not default config + logger.fatalQuit("Invalid config file") + } + } + } + + newContext + } +} diff --git a/src/main/scala/millfork/error/ConsoleLogger.scala b/src/main/scala/millfork/error/ConsoleLogger.scala index 55baa889..296c6787 100644 --- a/src/main/scala/millfork/error/ConsoleLogger.scala +++ b/src/main/scala/millfork/error/ConsoleLogger.scala @@ -4,10 +4,12 @@ import millfork.assembly.SourceLine import millfork.node.Position import scala.collection.mutable +import java.io.PrintStream class ConsoleLogger extends Logger { FatalErrorReporting.considerAsGlobal(this) + private var defaultUseStderr = false var verbosity = 0 var fatalWarnings = false @@ -15,6 +17,10 @@ class ConsoleLogger extends Logger { this.fatalWarnings = fatalWarnings } + def setOutput(useStderr: Boolean): Unit = { + this.defaultUseStderr = useStderr + } + var hasErrors = false private val sourceLines: mutable.Map[String, IndexedSeq[String]] = mutable.Map() @@ -27,11 +33,11 @@ class ConsoleLogger extends Logger { val line = lines.apply(lineIx) val column = pos.get.column - 1 val margin = " " - print(margin) - println(line) - print(margin) - print(" " * column) - println("^") + this.print(margin) + this.println(line) + this.print(margin) + this.print(" " * column) + this.println("^") } } } @@ -42,14 +48,14 @@ class ConsoleLogger extends Logger { override def info(msg: String, position: Option[Position] = None): Unit = { if (verbosity < 0) return - println("INFO: " + f(position) + msg) + this.println("INFO: " + f(position) + msg) printErrorContext(position) flushOutput() } override def debug(msg: String, position: Option[Position] = None): Unit = { if (verbosity < 1) return - println("DEBUG: " + f(position) + msg) + this.println("DEBUG: " + f(position) + msg) flushOutput() } @@ -59,7 +65,7 @@ class ConsoleLogger extends Logger { override def trace(msg: String, position: Option[Position] = None): Unit = { if (verbosity < 2) return - println("TRACE: " + f(position) + msg) + this.println("TRACE: " + f(position) + msg) flushOutput() } @@ -71,7 +77,7 @@ class ConsoleLogger extends Logger { override def warn(msg: String, position: Option[Position] = None): Unit = { if (verbosity < 0) return - println("WARN: " + f(position) + msg) + this.println("WARN: " + f(position) + msg) printErrorContext(position) flushOutput() if (fatalWarnings) { @@ -81,14 +87,14 @@ class ConsoleLogger extends Logger { override def error(msg: String, position: Option[Position] = None): Unit = { hasErrors = true - println("ERROR: " + f(position) + msg) + this.println("ERROR: " + f(position) + msg) printErrorContext(position) flushOutput() } override def fatal(msg: String, position: Option[Position] = None): Nothing = { hasErrors = true - println("FATAL: " + f(position) + msg) + this.println("FATAL: " + f(position) + msg) printErrorContext(position) flushOutput() throw new AssertionError(msg) @@ -96,7 +102,7 @@ class ConsoleLogger extends Logger { override def fatalQuit(msg: String, position: Option[Position] = None): Nothing = { hasErrors = true - println("FATAL: " + f(position) + msg) + this.println("FATAL: " + f(position) + msg) printErrorContext(position) flushOutput() System.exit(1) @@ -128,4 +134,10 @@ class ConsoleLogger extends Logger { file <- sourceLines.get(line.moduleName) line <- file.lift(line.line - 1) } yield line + + private def getOutputStream: PrintStream = if (this.defaultUseStderr) System.err else System.out + + private def print(x: String): Unit = getOutputStream.print(x) + + private def println(x: String): Unit = getOutputStream.println(x) } \ No newline at end of file diff --git a/src/main/scala/millfork/language/LanguageServerLogger.scala b/src/main/scala/millfork/language/LanguageServerLogger.scala new file mode 100644 index 00000000..918ac299 --- /dev/null +++ b/src/main/scala/millfork/language/LanguageServerLogger.scala @@ -0,0 +1,35 @@ +package millfork.language + +import millfork.error.Logger +import millfork.node.Position +import millfork.assembly.SourceLine + +class LanguageServerLogger extends Logger { + // TODO: Unused. Complete stub to send diagnostics to client + override def setFatalWarnings(fatalWarnings: Boolean): Unit = {} + + override def info(msg: String, position: Option[Position]): Unit = {} + + override def debug(msg: String, position: Option[Position]): Unit = {} + + override def trace(msg: String, position: Option[Position]): Unit = {} + + override def traceEnabled: Boolean = false + + override def debugEnabled: Boolean = false + + override def warn(msg: String, position: Option[Position]): Unit = {} + + override def error(msg: String, position: Option[Position]): Unit = {} + + override def fatal(msg: String, position: Option[Position]): Nothing = ??? + + override def fatalQuit(msg: String, position: Option[Position]): Nothing = ??? + + override def assertNoErrors(msg: String): Unit = {} + + override def addSource(filename: String, lines: IndexedSeq[String]): Unit = {} + + override def getLine(line: SourceLine): Option[String] = None + +} diff --git a/src/main/scala/millfork/language/MfLanguageClient.scala b/src/main/scala/millfork/language/MfLanguageClient.scala new file mode 100644 index 00000000..2f27749e --- /dev/null +++ b/src/main/scala/millfork/language/MfLanguageClient.scala @@ -0,0 +1,69 @@ +package millfork.language + +import org.eclipse.lsp4j.services.LanguageClient +import org.eclipse.lsp4j.jsonrpc.services.JsonNotification +import org.eclipse.lsp4j.MessageType +import org.eclipse.lsp4j.MessageParams + +trait MfLanguageClient extends LanguageClient { + + /** + * Display message in the editor "status bar", which should be displayed somewhere alongside the buffer. + * + * The status bar should always be visible to the user. + * + * - VS Code: https://code.visualstudio.com/docs/extensionAPI/vscode-api#StatusBarItem + */ + // @JsonNotification("metals/status") + // def metalsStatus(params: MetalsStatusParams): Unit + + /** + * Starts a long running task with no estimate for how long it will take to complete. + * + * - request cancellation from the server indicates that the task has completed + * - response with cancel=true indicates the client wishes to cancel the slow task + */ + // @JsonRequest("metals/slowTask") + // def metalsSlowTask( + // params: MetalsSlowTaskParams + // ): CompletableFuture[MetalsSlowTaskResult] + + // @JsonNotification("metals/executeClientCommand") + // def metalsExecuteClientCommand(params: ExecuteCommandParams): Unit + + final def refreshModel(): Unit = { + // val command = ClientCommands.RefreshModel.id + // val params = new ExecuteCommandParams(command, Nil.asJava) + // metalsExecuteClientCommand(params) + } + + /** + * Opens an input box to ask the user for input. + * + * @return the user provided input. The future can be cancelled, meaning + * the input box should be dismissed in the editor. + */ + // @JsonRequest("metals/inputBox") + // def metalsInputBox( + // params: MetalsInputBoxParams + // ): CompletableFuture[MetalsInputBoxResult] + + /** + * Opens an menu to ask the user to pick one of the suggested options. + * + * @return the user provided pick. The future can be cancelled, meaning + * the input box should be dismissed in the editor. + */ + // @JsonRequest("metals/quickPick") + // def metalsQuickPick( + // params: MetalsQuickPickParams + // ): CompletableFuture[MetalsQuickPickResult] + + final def showMessage(messageType: MessageType, message: String): Unit = { + val params = new MessageParams(messageType, message) + showMessage(params) + } + + def shutdown(): Unit = {} + +} diff --git a/src/main/scala/millfork/language/MfLanguageServer.scala b/src/main/scala/millfork/language/MfLanguageServer.scala new file mode 100644 index 00000000..e2199d32 --- /dev/null +++ b/src/main/scala/millfork/language/MfLanguageServer.scala @@ -0,0 +1,483 @@ +package millfork.language +import millfork.CompilationOptions +import millfork.parser.{ + MosSourceLoadingQueue, + ZSourceLoadingQueue, + MSourceLoadingQueue, + ParsedProgram +} +import millfork.Context + +import millfork.node.{ + FunctionDeclarationStatement, + ParameterDeclaration, + Position, + Node +} + +import org.eclipse.lsp4j.services.{ + LanguageServer, + TextDocumentService, + WorkspaceService +} +import org.eclipse.lsp4j.{ + InitializeParams, + InitializeResult, + ServerCapabilities, + Range +} +import org.eclipse.lsp4j.jsonrpc.services.JsonRequest + +import java.util.concurrent.CompletableFuture +import org.eclipse.lsp4j.TextDocumentPositionParams +import org.eclipse.lsp4j.Hover +import org.eclipse.lsp4j.jsonrpc.messages.Either +import java.{util => ju} +import scala.collection.mutable +import org.eclipse.lsp4j.MarkedString +import org.eclipse.lsp4j.jsonrpc.services.JsonNotification +import org.eclipse.lsp4j.InitializedParams +import org.eclipse.lsp4j.MarkupContent +import org.eclipse.lsp4j.DefinitionParams +import org.eclipse.lsp4j.Location +import net.liftweb.json._ +import net.liftweb.json.Serialization.{read, write} +import org.eclipse.lsp4j.MessageParams +import org.eclipse.lsp4j.MessageType +import org.eclipse.lsp4j.DidOpenTextDocumentParams +import org.eclipse.lsp4j.TextDocumentSyncKind +import org.eclipse.lsp4j.DidChangeTextDocumentParams +import java.nio.file.Path +import java.nio.file.Paths +import org.eclipse.lsp4j.VersionedTextDocumentIdentifier +import millfork.node.Program +import millfork.node.ImportStatement +import org.eclipse.lsp4j.ReferenceParams +import millfork.node.DeclarationStatement +import scala.collection.JavaConverters._ +import millfork.CpuFamily +import millfork.parser.ZSourceLoadingQueue +import millfork.parser.MSourceLoadingQueue + +class MfLanguageServer(context: Context, options: CompilationOptions) { + var client: Option[MfLanguageClient] = None + + val cachedModules: mutable.Map[String, Program] = mutable.Map() + private var cachedProgram: Option[ParsedProgram] = None + private val moduleNames: mutable.Map[String, String] = mutable.Map() + private val modulePaths: mutable.Map[String, Path] = mutable.Map() + + @JsonRequest("initialize") + def initialize( + params: InitializeParams + ): CompletableFuture[ + InitializeResult + ] = + CompletableFuture.completedFuture { + val capabilities = new ServerCapabilities() + capabilities.setHoverProvider(true) + capabilities.setDefinitionProvider(true) + capabilities.setTextDocumentSync(TextDocumentSyncKind.Full) + capabilities.setReferencesProvider(true) + + new InitializeResult(capabilities) + } + + @JsonNotification("initialized") + def initialized(params: InitializedParams): CompletableFuture[Unit] = + CompletableFuture.completedFuture { + populateProgramForPath() + } + + // @JsonRequest("getTextDocumentService") + // def getTextDocumentService(): CompletableFuture[TextDocumentService] = { + // val completableFuture = new CompletableFuture[InitializeResult]() + // completableFuture.complete(new TextDocumentService()) + // completableFuture + // } + + // @JsonRequest("getWorkspaceService") + // def getWorkspaceService(): CompletableFuture[WorkspaceService] = ??? + + @JsonRequest("exit") + def exit(): CompletableFuture[Unit] = ??? + + @JsonRequest("shutdown") + def shutdown(): CompletableFuture[Object] = ??? + + @JsonRequest("textDocument/didOpen") + def textDocumentDidOpen( + params: DidOpenTextDocumentParams + ): CompletableFuture[Unit] = + CompletableFuture.completedFuture { + val textDocument = params.getTextDocument() + val pathString = trimDocumentUri(textDocument.getUri()) + + val documentText = textDocument.getText().split("\n").toSeq + + rebuildASTForFile(pathString, documentText) + } + + @JsonRequest("textDocument/didChange") + def textDocumentDidChange( + params: DidChangeTextDocumentParams + ): CompletableFuture[Unit] = + CompletableFuture.completedFuture { + val pathString = trimDocumentUri(params.getTextDocument().getUri()) + + val documentText = + params.getContentChanges().get(0).getText().split("\n").toSeq + + rebuildASTForFile(pathString, documentText) + } + + def rebuildASTForFile(pathString: String, text: Seq[String]) = { + logEvent(TelemetryEvent("Rebuilding AST for module at path", pathString)) + + val path = Paths.get(pathString) + val moduleName = queue.extractName(pathString) + + logEvent( + TelemetryEvent( + "Path", + Map("path" -> path.toString(), "module" -> moduleName) + ) + ) + + val newProgram = queue.parseModuleWithLines( + moduleName, + path, + text, + context.includePath, + Left(None), + Nil + ) + + if (newProgram.isDefined) { + cachedModules.put(moduleName, newProgram.get) + + moduleNames.put(pathString, moduleName) + modulePaths.put(moduleName, Paths.get(pathString)) + + logEvent( + TelemetryEvent( + "Finished rebuilding AST for module at path", + pathString + ) + ) + } else { + logEvent( + TelemetryEvent("Failed to rebuild AST for module at path", pathString) + ) + } + } + + @JsonRequest("textDocument/definition") + def textDocumentDefinition( + params: DefinitionParams + ): CompletableFuture[Location] = + CompletableFuture.completedFuture { + val activePosition = params.getPosition() + + val statement = findExpressionAtPosition( + trimDocumentUri(params.getTextDocument().getUri()), + Position( + "", + activePosition.getLine() + 1, + activePosition.getCharacter() + 2, + 0 + ) + ) + + if (statement.isDefined) { + val (module, declaration) = statement.get + locationForExpression(declaration, module) + } else null + } + + @JsonRequest("textDocument/references") + def textDocumentReferences( + params: ReferenceParams + ): CompletableFuture[ju.List[Location]] = + CompletableFuture.completedFuture { + val activePosition = params.getPosition() + + val statement = findExpressionAtPosition( + trimDocumentUri(params.getTextDocument().getUri()), + Position( + "", + activePosition.getLine() + 1, + activePosition.getCharacter() + 2, + 0 + ) + ) + + if (statement.isDefined) { + val (declarationModule, declarationContent) = statement.get + + logEvent( + TelemetryEvent("Attempting to find references") + ) + + if ( + declarationContent + .isInstanceOf[DeclarationStatement] || declarationContent + .isInstanceOf[ParameterDeclaration] + ) { + val matchingExpressions = + // Only include declaration if params specify it + (if (params.getContext().isIncludeDeclaration()) + List((declarationModule, declarationContent)) + else List()) ++ NodeFinder + .matchingExpressionsForDeclaration( + cachedModules.toStream, + declarationContent + ) + + logEvent( + TelemetryEvent("Prepping references", matchingExpressions) + ) + + matchingExpressions + .sortBy { + case (_, expression) => + expression.position match { + case Some(value) => value.line + case None => 0 + } + } + .map { + case (module, expression) => { + try { + locationForExpression(expression, module) + } catch { + case _: Throwable => null + } + } + } + .filter(e => e != null) + .asJava + } else { + null + } + } else null + } + + @JsonRequest("textDocument/hover") + def textDocumentHover( + params: TextDocumentPositionParams + ): CompletableFuture[Hover] = + CompletableFuture.completedFuture { + val hoverPosition = params.getPosition() + + val statement = findExpressionAtPosition( + trimDocumentUri(params.getTextDocument().getUri()), + Position( + "", + // Millfork positions start at 1,2, rather than 0,0, so add to each coord + hoverPosition.getLine() + 1, + hoverPosition.getCharacter() + 2, + 0 + ) + ) + + if (statement.isDefined) { + val (_, declarationContent) = statement.get + val formatting = NodeFormatter.symbol(declarationContent) + val docstring = NodeFormatter.docstring(declarationContent) + + if (formatting.isDefined) + new Hover( + new MarkupContent( + "markdown", + NodeFormatter.hover( + formatting.get, + docstring.getOrElse("") + ) + ) + ) + else null + } else null + } + + /** + * Builds the AST for the entire program, based on the configured "inputFileNames" + */ + private def populateProgramForPath() = { + logEvent( + TelemetryEvent("Building program AST") + ) + + var program = queue.run() + + logEvent( + TelemetryEvent("Finished building AST") + ) + + cachedProgram = Some(program) + program.parsedModules.foreach { + case (moduleName, program) => + cachedModules.put(moduleName, program) + } + program.modulePaths.foreach { + case (moduleName, path) => modulePaths.put(moduleName, path) + } + } + + private def moduleNameForPath(documentPath: String) = + moduleNames.get(documentPath).getOrElse { + throw new Exception("Cannot find module at " + documentPath) + } + + private def findExpressionAtPosition( + documentPath: String, + position: Position + ): Option[(String, Node)] = { + val moduleName = moduleNameForPath(documentPath) + + val currentModuleDeclarations = cachedModules.get(moduleName) + + if (currentModuleDeclarations.isEmpty) { + return None + } + + val (node, enclosingDeclarations) = NodeFinder.findNodeAtPosition( + currentModuleDeclarations.get, + position + ) + + if (node.isDefined) { + logEvent(TelemetryEvent("Found node at position", node)) + + // Build ordered scopes to search through + // First, our current enclosing scope, then the current module (which contains the current scope), then all other modules + val orderedScopes = List( + (moduleName, enclosingDeclarations.get), + (moduleName, currentModuleDeclarations.get.declarations) + ) ++ cachedModules.toList + .filter { + case (cachedModuleName, program) => cachedModuleName != moduleName + } + .map { + case (cachedModuleName, program) => + (cachedModuleName, program.declarations) + } + + val usage = + NodeFinder.findDeclarationForUsage(orderedScopes, node.get) + + if (usage.isDefined) { + logEvent(TelemetryEvent("Found original declaration", usage)) + usage + } else Some((moduleName, node.get)) + } else { + logEvent(TelemetryEvent("Cannot find node for position", position)) + None + } + } + + /** + * Builds highlighted `Location` of a declaration or usage + */ + private def locationForExpression( + expression: Node, + module: String + ): Location = { + val name = NodeFinder.extractNodeName(expression) + val position = expression.position.get + val modulePath = modulePaths.getOrElse( + module, { + logEvent( + TelemetryEvent( + "Could not find path for module", + module + ) + ) + null + } + ) + + if (expression.isInstanceOf[ImportStatement]) { + // ImportStatement declaration is the entire "file". Set position to 1,1 + val importPosition = Position(module, 1, 1, 0) + return new Location( + modulePath.toUri().toString(), + new Range( + mfPositionToLSP4j(importPosition), + mfPositionToLSP4j(importPosition) + ) + ) + } + + val endPosition = if (name.isDefined) { + Position( + module, + position.line, + position.column + name.get.length, + 0 + ) + } else position + + new Location( + modulePath.toUri().toString(), + new Range( + mfPositionToLSP4j(position), + mfPositionToLSP4j(endPosition) + ) + ) + } + + private def queue() = + CpuFamily.forType(options.platform.cpu) match { + case CpuFamily.M6502 => + new MosSourceLoadingQueue( + initialFilenames = context.inputFileNames, + includePath = context.includePath, + options = options + ) + case CpuFamily.I80 | CpuFamily.I86 => + new ZSourceLoadingQueue( + initialFilenames = context.inputFileNames, + includePath = context.includePath, + options = options + ) + case CpuFamily.M6809 => + new MSourceLoadingQueue( + initialFilenames = context.inputFileNames, + includePath = context.includePath, + options = options + ) + } + + private def mfPositionToLSP4j( + position: Position + ): org.eclipse.lsp4j.Position = + new org.eclipse.lsp4j.Position( + position.line - 1, + // If subtracting 1 would be < 0, set to 0 + if (position.column < 1) 0 else position.column - 1 + ) + + private def logEvent(event: TelemetryEvent): Unit = { + val languageClient = client.getOrElse { + // Language client not registered + return + } + + implicit val formats = Serialization.formats(NoTypeHints) + val serializedEvent = write(event) + + languageClient.logMessage( + new MessageParams(MessageType.Log, serializedEvent) + ) + } + + private def trimDocumentUri(uri: String): String = + uri + .replaceFirst("file:(//)?", "") + // Trim Windows path oddities provided by VSCode (may not be for all LSP clients) + .replaceFirst("%3A", ":") + .replaceFirst("/([A-Za-z]):", "$1:") +} + +case class TelemetryEvent(message: String, data: Any = None) diff --git a/src/main/scala/millfork/language/NodeFinder.scala b/src/main/scala/millfork/language/NodeFinder.scala new file mode 100644 index 00000000..cfcf4615 --- /dev/null +++ b/src/main/scala/millfork/language/NodeFinder.scala @@ -0,0 +1,406 @@ +package millfork.language + +import millfork.node.{ + DeclarationStatement, + Expression, + FunctionCallExpression, + FunctionDeclarationStatement, + Node, + Program, + Position, + VariableDeclarationStatement, + VariableExpression +} +import millfork.parser.ParsedProgram + +import scala.collection.mutable +import millfork.node.ExpressionStatement +import millfork.node.ImportStatement +import millfork.node.ParameterDeclaration +import millfork.env.ByConstant +import millfork.env.ByReference +import millfork.env.ByVariable +import millfork.node.ArrayDeclarationStatement +import millfork.node.AliasDefinitionStatement +import millfork.node.IndexedExpression +import millfork.node.Statement +import millfork.node.SumExpression +import millfork.env.ByLazilyEvaluableExpressionVariable +import millfork.node.EnumDefinitionStatement +import millfork.node.LabelStatement +import millfork.node.StructDefinitionStatement +import millfork.node.TypeDefinitionStatement +import millfork.node.UnionDefinitionStatement + +object NodeFinder { + + /** + * Finds the declaration matching the provided node + * + * @param orderedScopes A list, ordered by decreasing scope (function, local module, global module), + * of tuples containing the module name and all declarations contained wherein + * @param node The node to find the source declaration for + * @return A tuple containing the module name and "declaration" (could be `ParameterDeclaration`, hence + * the `Node` type); `None` otherwise + */ + def findDeclarationForUsage( + orderedScopes: List[(String, List[DeclarationStatement])], + node: Node + ): Option[(String, Node)] = { + node match { + case importStatement: ImportStatement => + Some((importStatement.filename, importStatement)) + case expression: Expression => { + for ((moduleName, scopedDeclarations) <- orderedScopes) { + val declaration = + matchingDeclarationForExpression( + expression, + scopedDeclarations + ) + + if (declaration.isDefined) { + return Some((moduleName, declaration.get)) + } + } + + return None + } + case default => None + } + } + + /** + * Searches for the declaration matching the type and name of the provided expression + * + * @param expression The expression to find the root declaration for + * @param declarations The declarations to check + * @return The matching declaration if found; `None` otherwise + */ + private def matchingDeclarationForExpression( + expression: Expression, + declarations: List[DeclarationStatement] + ): Option[Node] = + expression match { + case FunctionCallExpression(name, expressions) => + declarations + .filter(d => d.isInstanceOf[FunctionDeclarationStatement]) + .find(d => d.name == name) + case VariableExpression(name) => + matchVariableExpressionName(name, declarations) + case IndexedExpression(name, index) => + matchVariableExpressionName(name, declarations) + case default => None + } + + /** + * Searches for the declaration matching a variable name + * + * @param name The name of the variable + * @param declarations The declarations to check + * @return The matching declaration if found; `None` otherwise + */ + private def matchVariableExpressionName( + name: String, + declarations: List[DeclarationStatement] + ) = + declarations + .flatMap(d => + d match { + // Extract nested declarations (and `ParameterDeclaration`s, which do not extend `DeclarationStatement`) + // from functions + case functionDeclaration: FunctionDeclarationStatement => + recursivelyFlatten(functionDeclaration) + .filter(e => + e.isInstanceOf[DeclarationStatement] || e + .isInstanceOf[ParameterDeclaration] + ) + case default => List(default) + } + ) + .find(d => + d match { + case variableDeclaration: VariableDeclarationStatement => + variableDeclaration.name == name + case ParameterDeclaration(typ, assemblyParamPassingConvention) => + assemblyParamPassingConvention match { + case ByConstant(pName) => pName == name + case ByReference(pName) => pName == name + case ByVariable(pName) => pName == name + case default => false + } + case arrayDeclaration: ArrayDeclarationStatement => + arrayDeclaration.name == name + case AliasDefinitionStatement(aName, target, important) => + aName == name + case default => false + } + ) + + /** + * Finds all expressions referencing a declaration + * + * @param parsedModules All program modules + * @param declaration The declaration to find all references for + * @return A list of tuples, containing the module name and the corresponding expression + */ + def matchingExpressionsForDeclaration( + parsedModules: Stream[(String, Program)], + declaration: Node + ): List[(String, Node)] = { + parsedModules.toStream.flatMap { + case (module, program) => { + val allDeclarations = + program.declarations + .flatMap(d => d.getAllExpressions) + .flatMap(flattenNestedExpressions) + + declaration match { + case f: FunctionDeclarationStatement => + allDeclarations + .filter(d => d.isInstanceOf[FunctionCallExpression]) + .map(d => d.asInstanceOf[FunctionCallExpression]) + .filter(d => d.functionName == f.name) + .map(d => (module, d)) + case v: VariableDeclarationStatement => + allDeclarations + .filter(d => d.isInstanceOf[VariableExpression]) + .map(d => d.asInstanceOf[VariableExpression]) + .filter(d => d.name == v.name) + .map(d => (module, d)) + case a: ArrayDeclarationStatement => + allDeclarations + .filter(d => d.isInstanceOf[IndexedExpression]) + .map(d => d.asInstanceOf[IndexedExpression]) + .filter(d => d.name == a.name) + .map(d => (module, d)) + case p: ParameterDeclaration => { + val pName = p.assemblyParamPassingConvention match { + case ByConstant(name) => Some(name) + case ByReference(name) => Some(name) + case ByVariable(name) => Some(name) + case ByLazilyEvaluableExpressionVariable(name) => + Some(name) + case _ => None + } + + if (pName.isDefined) { + allDeclarations + .filter(d => extractNodeName(d) == pName) + .map(d => (module, d)) + } else List() + } + case default => List() + } + } + }.toList + } + + /** + * Finds the node and enclosing declaration scope for a given position + * + * @param program The program containing the position + * @param position The position of the node to find + * @return A tuple containing the found node, and a list of enclosing declaration scopes + */ + def findNodeAtPosition( + program: Program, + position: Position + ): (Option[Node], Option[List[DeclarationStatement]]) = { + val line = position.line + val column = position.column + + val declarations = + findEnclosingDeclarationsAtLine(program.declarations, line) + + if (declarations.isEmpty) { + return (None, None) + } + + if (lineOrNegOne(declarations.get.head.position) != line) { + // Declaration is a function or similar wrapper + // Find inner expressions + if (declarations.get.length > 1) { + throw new Exception("Unexpected number of declarations") + } + + return ( + findNodeAtColumn( + recursivelyFlatten(declarations.get.head), + line, + column + ), + declarations + ) + } + + // All declarations are current line, find matching node for column + ( + findNodeAtColumn( + declarations.get + .flatMap(recursivelyFlatten), + line, + column + ), + declarations + ) + } + + /** + * Finds the narrowest top level declaration scope for the given line (typically enclosing function) + * + * @param declarations All program declarations in the file + * @param line The line to search for + */ + private def findEnclosingDeclarationsAtLine( + declarations: List[DeclarationStatement], + line: Int + ): Option[List[DeclarationStatement]] = { + var lastDeclarations: Option[List[DeclarationStatement]] = + if (declarations.length > 0) Some(declarations.take(1)) else None + + for ((nextDeclaration, i) <- declarations.view.zipWithIndex) { + val nextLine = lineOrNegOne(nextDeclaration.position) + + if (nextLine == line) { + // Declaration is on this line + // Check for additional declarations on this line + val newDeclarations = mutable.MutableList(nextDeclaration) + + for (declarationIndex <- i to declarations.length - 1) { + val checkDeclaration = declarations(declarationIndex) + + if ( + checkDeclaration.position.isDefined && checkDeclaration.position.get.line == line + ) { + newDeclarations += checkDeclaration + } else { + // Line doesn't match, done with this line + return Some(newDeclarations.toList) + } + } + + return Some(newDeclarations.toList) + } else if (nextLine < line) { + // Closer to desired line + lastDeclarations = Some(List(nextDeclaration)) + } + } + + lastDeclarations + } + + /** + * Searches for closest column index less than the selected column on a given line + */ + private def findNodeAtColumn( + nodes: List[Node], + line: Int, + column: Int + ): Option[Node] = { + var lastNode: Option[Node] = None + var lastPosition: Option[Position] = None + + // Only consider nodes on this line (if we're opening a declaration, it could span multiple lines) + var flattenedNodes = nodes.flatMap(flattenNestedExpressions) + + for (nextNode <- flattenedNodes) + if (lineOrNegOne(nextNode.position) == line) { + if (nextNode.position.isEmpty) { + throw new Error("Missing position for node " + nextNode.toString()) + } + + if ( + colOrNegOne(nextNode.position) < column && colOrNegOne( + nextNode.position + // Allow equality, because later nodes are of higher specificity + ) >= colOrNegOne(lastPosition) + ) { + lastNode = Some(nextNode) + lastPosition = nextNode.position + } + } + + lastNode + } + + private def lineOrNegOne(position: Option[Position]): Int = + position match { + case Some(pos) => pos.line + case None => -1 + } + + private def colOrNegOne(position: Option[Position]): Int = + position match { + case Some(pos) => pos.column + case None => -1 + } + + /** + * Recursively flattens a node tree into a tree containing the node itself and all of its "owned" nodes. + * In particular, function declarations don't call `getAllExpressions`, instead opting to manually pull out + * `params` and `statements` to allow for proper access + * + * @param node The root of the tree to process + * @return A flatten list of all nodes + */ + private def recursivelyFlatten(node: Node): List[Node] = + node match { + case functionDeclaration: FunctionDeclarationStatement => + List( + functionDeclaration + ) ++ functionDeclaration.params ++ functionDeclaration.statements + .getOrElse(List()) + .flatMap(recursivelyFlatten) + case statement: Statement => + List(statement) ++ statement.getAllExpressions.flatMap( + recursivelyFlatten + ) + case expression: Expression => List(expression) + case default => List(default) + } + + /** + * Returns all of the expressions contained within an expression, including itself + */ + private def flattenNestedExpressions(node: Node): List[Node] = + node match { + case statement: ExpressionStatement => { + val innerExpressions = flattenNestedExpressions( + statement.expression + ) + + List(statement.expression) ++ innerExpressions + } + case functionExpression: FunctionCallExpression => + List(functionExpression) ++ + functionExpression.expressions + .flatMap(flattenNestedExpressions) + case indexExpression: IndexedExpression => + List(indexExpression) ++ + flattenNestedExpressions(indexExpression.index) + case sumExpression: SumExpression => + List(sumExpression) ++ sumExpression.expressions.flatMap(e => + flattenNestedExpressions(e._2) + ) + case default => List(default) + } + + /** + * Returns the name of the node, if it exists + */ + def extractNodeName(node: Node): Option[String] = + node match { + case a: AliasDefinitionStatement => Some(a.name) + case a: ArrayDeclarationStatement => Some(a.name) + case e: EnumDefinitionStatement => Some(e.name) + case f: FunctionCallExpression => Some(f.functionName) + case f: FunctionDeclarationStatement => Some(f.name) + case l: LabelStatement => Some(l.name) + case s: StructDefinitionStatement => Some(s.name) + case t: TypeDefinitionStatement => Some(t.name) + case u: UnionDefinitionStatement => Some(u.name) + case v: VariableDeclarationStatement => Some(v.name) + case v: VariableExpression => Some(v.name) + case _ => None + } +} diff --git a/src/main/scala/millfork/language/NodeFormatter.scala b/src/main/scala/millfork/language/NodeFormatter.scala new file mode 100644 index 00000000..142243af --- /dev/null +++ b/src/main/scala/millfork/language/NodeFormatter.scala @@ -0,0 +1,293 @@ +package millfork.language + +import millfork.node.Node +import millfork.node.DeclarationStatement +import millfork.node.FunctionDeclarationStatement +import millfork.node.ParameterDeclaration +import millfork.node.Expression +import millfork.node.VariableExpression +import millfork.node.VariableDeclarationStatement +import millfork.node.LiteralExpression +import millfork.env.ParamPassingConvention +import millfork.env.ByConstant +import millfork.env.ByVariable +import millfork.env.ByReference +import millfork.node.ImportStatement +import millfork.env.ByLazilyEvaluableExpressionVariable +import millfork.env.ByMosRegister +import millfork.node.MosRegister +import millfork.env.ByZRegister +import millfork.node.ZRegister +import millfork.env.ByM6809Register +import millfork.node.M6809Register +import millfork.node.ArrayDeclarationStatement +import millfork.output.MemoryAlignment +import millfork.output.NoAlignment +import millfork.output.DivisibleAlignment +import millfork.output.WithinPageAlignment +import millfork.node.AliasDefinitionStatement +import java.util.regex.Pattern +import scala.collection.mutable.ListBuffer + +object NodeFormatter { + val docstringAsteriskPattern = + Pattern.compile("^\\s*\\*? *", Pattern.MULTILINE) + + val docstringParamPattern = + Pattern.compile("@param (\\w+) +(.*)$", Pattern.MULTILINE) + val docstringReturnsPattern = + Pattern.compile("@returns +(.*)$", Pattern.MULTILINE) + + // TODO: Remove Option + def symbol(node: Node): Option[String] = + node match { + case statement: DeclarationStatement => + statement match { + case functionStatement: FunctionDeclarationStatement => { + val builder = new StringBuilder() + + if (functionStatement.constPure) { + builder.append("const ") + } + + if (functionStatement.interrupt) { + builder.append("interrupt ") + } + + if (functionStatement.kernalInterrupt) { + builder.append("kernal_interrupt ") + } + + if (functionStatement.assembly) { + builder.append("asm ") + } + + // Cannot have both "macro" and "inline" + if (functionStatement.isMacro) { + builder.append("macro ") + } else if ( + functionStatement.inlinable.isDefined && functionStatement.inlinable.get + ) { + builder.append("inline ") + } + + builder.append( + s"""${functionStatement.resultType} ${functionStatement.name}(${functionStatement.params + .map(symbol) + .filter(n => n.isDefined) + .map(n => n.get) + .mkString(", ")})""" + ) + + Some(builder.toString()) + } + case variableStatement: VariableDeclarationStatement => { + val builder = new StringBuilder() + + if (variableStatement.constant) { + builder.append("const ") + } + + if (variableStatement.volatile) { + builder.append("volatile ") + } + + builder.append( + s"""${variableStatement.typ} ${variableStatement.name}""" + ) + + if (variableStatement.initialValue.isDefined) { + val formattedInitialValue = symbol( + variableStatement.initialValue.get + ) + + if (formattedInitialValue.isDefined) { + builder.append(s""" = ${formattedInitialValue.get}""") + } + } + + Some(builder.toString()) + } + case importStatement: ImportStatement => + Some(s"""import ${importStatement.filename}""") + case arrayStatement: ArrayDeclarationStatement => { + val builder = new StringBuilder() + + if (arrayStatement.const) { + builder.append("const ") + } + + builder.append( + s"""array(${arrayStatement.elementType}) ${arrayStatement.name}""" + ) + + if (arrayStatement.length.isDefined) { + val formattedLength = symbol(arrayStatement.length.get) + + if (formattedLength.isDefined) { + builder.append(s""" [${formattedLength.get}]""") + } + } + + if (arrayStatement.alignment.isDefined) { + val formattedAlignment = symbol( + arrayStatement.alignment.get + ) + + if (formattedAlignment.isDefined) { + builder.append(s""" align(${formattedAlignment.get})""") + } + } + + if (arrayStatement.address.isDefined) { + val formattedAddress = symbol(arrayStatement.address.get) + + if (formattedAddress.isDefined) { + builder.append(s""" @ ${formattedAddress.get}""") + } + } + + if (arrayStatement.elements.isDefined) { + val formattedInitialValue = arrayStatement.elements.get + .getAllExpressions(false) + .map(e => symbol(e)) + .filter(e => e.isDefined) + .map(e => e.get) + .mkString(", ") + + builder.append(s""" = [${formattedInitialValue}]""") + } + + Some(builder.toString()) + } + case AliasDefinitionStatement(name, target, important) => { + val builder = new StringBuilder() + + builder.append(s"""alias ${name} = ${target}""") + + if (important) { + builder.append("!") + } + + Some(builder.toString()) + } + // TODO: Finish + case default => None + } + case ParameterDeclaration(typ, assemblyParamPassingConvention) => + Some(s"""${typ} ${symbol(assemblyParamPassingConvention)}""") + case expression: Expression => + expression match { + case LiteralExpression(value, _) => Some(s"""${value}""") + case VariableExpression(name) => Some(s"""${name}""") + case default => None + } + case default => None + } + + def symbol(paramConvention: ParamPassingConvention): String = + paramConvention match { + case ByConstant(name) => name + case ByVariable(name) => name + case ByReference(name) => name + case ByLazilyEvaluableExpressionVariable(name) => name + case ByMosRegister(register) => + MosRegister.toString(register).getOrElse("") + case ByZRegister(register) => ZRegister.toString(register).getOrElse("") + case ByM6809Register(register) => + M6809Register.toString(register).getOrElse("") + } + + def symbol(alignment: MemoryAlignment): Option[String] = + alignment match { + case NoAlignment => None + // TOOD: Improve + case DivisibleAlignment(divisor) => Some(s"""${divisor}""") + case WithinPageAlignment => Some("Within page") + } + + def docstring(node: Node): Option[String] = { + val docComment = node match { + case f: FunctionDeclarationStatement => f.docComment + case v: VariableDeclarationStatement => v.docComment + case a: ArrayDeclarationStatement => a.docComment + case _ => None + } + + if (docComment.isEmpty) { + return None + } + + val baseString = docComment.get.text + + var strippedString = docstringAsteriskPattern + .matcher(baseString.stripSuffix("*/")) + .replaceAll("") + + val matchGroups = new ListBuffer[(String, String, Range)]() + + val paramMatcher = docstringParamPattern.matcher(strippedString) + while (paramMatcher.find()) { + matchGroups += ( + ( + paramMatcher.group(1), + paramMatcher + .group(2), + Range(paramMatcher.start(), paramMatcher.end()) + ) + ) + } + + val builder = new StringBuilder(strippedString) + + for (param <- matchGroups.reverse) { + val (paramName, description, range) = param + builder.replace( + range.start, + range.end, + s"\n_@param_ `${paramName}` \u2014 ${description.trim()}\n" + ) + } + + val returnMatch = docstringReturnsPattern.matcher(builder.toString()) + + if (returnMatch.find()) { + builder.replace( + returnMatch.start(), + returnMatch.end(), + s"\n_@returns_ \u2014 ${returnMatch.group(1).trim()}" + ) + } + + return Some(builder.toString()) + } + + /** + * Render the textDocument/hover result into markdown. + * + * @param symbolSignature The signature of the symbol over the cursor, for example + * "def map[B](fn: A => B): Option[B]" + * @param docstring The Markdown documentation string for the symbol. + */ + def hover( + symbolSignature: String, + docstring: String + ): String = { + val markdown = new StringBuilder() + + if (symbolSignature.nonEmpty) { + markdown + .append("```mfk\n") + .append(symbolSignature) + .append("\n```") + } + + if (docstring.nonEmpty) + markdown + .append("\n---\n") + .append(docstring) + .append("\n") + + markdown.toString() + } +} diff --git a/src/main/scala/millfork/node/Node.scala b/src/main/scala/millfork/node/Node.scala index ad2e019e..b0b0e3e9 100644 --- a/src/main/scala/millfork/node/Node.scala +++ b/src/main/scala/millfork/node/Node.scala @@ -253,12 +253,46 @@ object M6809NiceFunctionProperty { object MosRegister extends Enumeration { val A, X, Y, AX, AY, YA, XA, XY, YX, AW = Value + + private val registerStringToValue = Map[String, MosRegister.Value]( + "xy" -> MosRegister.XY, + "yx" -> MosRegister.YX, + "ax" -> MosRegister.AX, + "ay" -> MosRegister.AY, + "xa" -> MosRegister.XA, + "ya" -> MosRegister.YA, + "a" -> MosRegister.A, + "x" -> MosRegister.X, + "y" -> MosRegister.Y, + ) + private val registerValueToString = registerStringToValue.map { case (key, value) => (value, key)}.toMap + + def fromString(name: String): Option[MosRegister.Value] = registerStringToValue.get(name) + def toString(value: MosRegister.Value): Option[String] = registerValueToString.get(value) } object ZRegister extends Enumeration { val A, B, C, D, E, H, L, AF, BC, HL, DE, SP, IXH, IXL, IYH, IYL, IX, IY, R, I, MEM_HL, MEM_BC, MEM_DE, MEM_IX_D, MEM_IY_D, MEM_ABS_8, MEM_ABS_16, IMM_8, IMM_16 = Value + val registerStringToValue = Map[String, ZRegister.Value]( + "hl" -> ZRegister.HL, + "bc" -> ZRegister.BC, + "de" -> ZRegister.DE, + "a" -> ZRegister.A, + "b" -> ZRegister.B, + "c" -> ZRegister.C, + "d" -> ZRegister.D, + "e" -> ZRegister.E, + "h" -> ZRegister.H, + "l" -> ZRegister.L, + ) + + private val registerValueToString = registerStringToValue.map { case (key, value) => (value, key)}.toMap + + def fromString(name: String): Option[ZRegister.Value] = registerStringToValue.get(name) + def toString(value: ZRegister.Value): Option[String] = registerValueToString.get(value) + def registerSize(reg: Value): Int = reg match { case AF | BC | DE | HL | IX | IY | IMM_16 => 2 case A | B | C | D | E | H | L | IXH | IXL | IYH | IYL | R | I | IMM_8 => 1 @@ -280,6 +314,24 @@ object ZRegister extends Enumeration { object M6809Register extends Enumeration { val A, B, D, DP, X, Y, U, S, PC, CC = Value + val registerStringToValue = Map[String, M6809Register.Value]( + "x" -> M6809Register.X, + "y" -> M6809Register.Y, + "s" -> M6809Register.S, + "u" -> M6809Register.U, + "a" -> M6809Register.A, + "b" -> M6809Register.B, + "d" -> M6809Register.D, + "dp" -> M6809Register.DP, + "pc" -> M6809Register.PC, + "cc" -> M6809Register.CC, + ) + + private val registerValueToString = registerStringToValue.map { case (key, value) => (value, key)}.toMap + + def fromString(name: String): Option[M6809Register.Value] = registerStringToValue.get(name) + def toString(value: M6809Register.Value): Option[String] = registerValueToString.get(value) + def registerSize(reg: Value): Int = reg match { case D | X | Y | U | S | PC => 2 case A | B | DP | CC => 1 @@ -445,6 +497,18 @@ sealed trait Statement extends Node { sealed trait DeclarationStatement extends Statement { def name: String + var docComment: Option[DocComment] = None +} + +object DeclarationStatement { + implicit class DeclarationStatementOps[D<:DeclarationStatement](val declaration: D) extends AnyVal { + def docComment(comment: Option[DocComment]): D = { + if (comment.isDefined) { + declaration.docComment = comment + } + declaration + } + } } sealed trait BankedDeclarationStatement extends DeclarationStatement { @@ -833,4 +897,6 @@ object MosAssemblyStatement { def implied(opcode: Opcode.Value, elidability: Elidability.Value) = MosAssemblyStatement(opcode, AddrMode.Implied, LiteralExpression(0, 1), elidability) def nonexistent(opcode: Opcode.Value) = MosAssemblyStatement(opcode, AddrMode.DoesNotExist, LiteralExpression(0, 1), elidability = Elidability.Elidable) -} \ No newline at end of file +} + +case class DocComment(text: String) extends Node {} \ No newline at end of file diff --git a/src/main/scala/millfork/parser/AbstractSourceLoadingQueue.scala b/src/main/scala/millfork/parser/AbstractSourceLoadingQueue.scala index 66c1b0bf..ca971a7b 100644 --- a/src/main/scala/millfork/parser/AbstractSourceLoadingQueue.scala +++ b/src/main/scala/millfork/parser/AbstractSourceLoadingQueue.scala @@ -1,7 +1,7 @@ package millfork.parser import java.nio.charset.StandardCharsets -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import fastparse.core.Parsed.{Failure, Success} import millfork.{CompilationFlag, CompilationOptions, Tarjan} @@ -10,11 +10,14 @@ import millfork.node.{AliasDefinitionStatement, DeclarationStatement, ImportStat import scala.collection.mutable import scala.collection.convert.ImplicitConversionsToScala._ +case class ParsedProgram(compilationOrderProgram: Program, parsedModules: Map[String, Program], modulePaths: Map[String, Path]) + abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String], val includePath: List[String], val options: CompilationOptions) { protected val parsedModules: mutable.Map[String, Program] = mutable.Map[String, Program]() + protected val modulePaths: mutable.Map[String, Path] = mutable.Map[String, Path]() protected val moduleDependecies: mutable.Set[(String, String)] = mutable.Set[(String, String)]() protected val moduleQueue: mutable.Queue[() => Unit] = mutable.Queue[() => Unit]() val extension: String = ".mfk" @@ -41,7 +44,12 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String], encodingConversionAliases } - def run(): Program = { + /** + * Tokenizes and parses the configured source file and modules + * + * @return A ParsedProgram containing an ordered set of statements in order of compilation dependencies, and each individual parsed module + */ + def run(): ParsedProgram = { for { initialFilename <- initialFilenames startingModule <- options.platform.startingModules @@ -76,7 +84,8 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String], options.log.assertNoErrors("Parse failed") val compilationOrder = Tarjan.sort(parsedModules.keys, moduleDependecies) options.log.debug("Compilation order: " + compilationOrder.mkString(", ")) - compilationOrder.filter(parsedModules.contains).map(parsedModules).reduce(_ + _).applyImportantAliases + + ParsedProgram(compilationOrder.filter(parsedModules.contains).map(parsedModules).reduce(_ + _).applyImportantAliases, parsedModules.toMap, modulePaths.toMap) } def lookupModuleFile(includePath: List[String], moduleName: String, position: Option[Position]): String = { @@ -98,13 +107,24 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String], if (templateParams.isEmpty) moduleNameBase else moduleNameBase + templateParams.mkString("<", ",", ">") } + /** + * Finds module path and builds module AST, adding to `parsedModules` + */ def parseModule(moduleName: String, includePath: List[String], why: Either[Option[Position], String], templateParams: List[String]): Unit = { val filename: String = why.fold(p => lookupModuleFile(includePath, moduleName, p), s => s) options.log.debug(s"Parsing $filename") val path = Paths.get(filename) + + modulePaths.put(moduleName, path) + + parseModuleWithLines(moduleName, path, Files.readAllLines(path, StandardCharsets.UTF_8).toIndexedSeq, includePath, why, templateParams) + } + + def parseModuleWithLines(moduleName: String, path: Path, lines: Seq[String], includePath: List[String], why: Either[Option[Position], String], templateParams: List[String]): Option[Program] = { val parentDir = path.toFile.getAbsoluteFile.getParent val shortFileName = path.getFileName.toString - val PreprocessingResult(src, featureConstants, pragmas) = Preprocessor(options, shortFileName, Files.readAllLines(path, StandardCharsets.UTF_8).toIndexedSeq, templateParams) + + val PreprocessingResult(src, featureConstants, pragmas) = Preprocessor(options, shortFileName, lines, templateParams) for (pragma <- pragmas) { if (!supportedPragmas(pragma._1) && options.flag(CompilationFlag.BuggyCodeWarning)) { options.log.warn(s"Unsupported pragma: #pragma ${pragma._1}", Some(Position(moduleName, pragma._2, 1, 0))) @@ -126,16 +146,19 @@ abstract class AbstractSourceLoadingQueue[T](val initialFilenames: List[String], case _ => () } } + Some(prog) case f@Failure(a, b, d) => - options.log.error(s"Failed to parse the module `$moduleName` in $filename", Some(parser.indexToPosition(f.index, parser.lastLabel))) + options.log.error(s"Failed to parse the module `$moduleName` in ${path.toString()}", Some(parser.indexToPosition(f.index, parser.lastLabel))) if (parser.lastLabel != "") { options.log.error(s"Syntax error: ${parser.lastLabel} expected", Some(parser.lastPosition)) } else { options.log.error("Syntax error", Some(parser.lastPosition)) } + None } } + // TODO: Separate from Queue def extractName(i: String): String = { val noExt = i.stripSuffix(extension) val lastSlash = noExt.lastIndexOf('/') max noExt.lastIndexOf('\\') diff --git a/src/main/scala/millfork/parser/M6809Parser.scala b/src/main/scala/millfork/parser/M6809Parser.scala index c4f960cd..7515b535 100644 --- a/src/main/scala/millfork/parser/M6809Parser.scala +++ b/src/main/scala/millfork/parser/M6809Parser.scala @@ -45,21 +45,10 @@ case class M6809Parser(filename: String, val asmOpcode: P[(MOpcode.Value, Option[MAddrMode])] = (position() ~ (letter.rep ~ ("2" | "3").?).! ).map { case (p, o) => MOpcode.lookup(o, Some(p), log) } - private def mapRegister(p: (Position, String)): M6809Register.Value = p._2.toLowerCase(Locale.ROOT) match { - case "x" => M6809Register.X - case "y" => M6809Register.Y - case "s" => M6809Register.S - case "u" => M6809Register.U - case "a" => M6809Register.A - case "b" => M6809Register.B - case "d" => M6809Register.D - case "dp" => M6809Register.DP - case "pc" => M6809Register.PC - case "cc" => M6809Register.CC - case _ => + private def mapRegister(p: (Position, String)): M6809Register.Value = M6809Register.fromString(p._2.toLowerCase(Locale.ROOT)).getOrElse({ log.error("Invalid register " + p._2, Some(p._1)) M6809Register.D - } + }) // only used for TFR, EXG, PSHS, PULS, PSHU, PULU, so it is allowed to accept any register name in order to let parsing continue: val anyRegister: P[M6809Register.Value] = P(position() ~ identifier.!).map(mapRegister) diff --git a/src/main/scala/millfork/parser/MfParser.scala b/src/main/scala/millfork/parser/MfParser.scala index 5a95d68a..e683ea19 100644 --- a/src/main/scala/millfork/parser/MfParser.scala +++ b/src/main/scala/millfork/parser/MfParser.scala @@ -83,11 +83,20 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri val comment: P[Unit] = P("//" ~ CharsWhile(c => c != '\n' && c != '\r', min = 0) ~ ("\r\n" | "\r" | "\n")) + val recursiveMultilineCommentContent: P[Unit] = P(CharsWhile(c => c != '*', min = 0) ~ ("*/" | ("*" ~ recursiveMultilineCommentContent))) + + val docComment: P[DocComment] = for { + p <- position() + text <- ("/**" ~ recursiveMultilineCommentContent.!) + } yield DocComment(text).pos(p) + + val multilineComment: P[Unit] = P("/*" ~ !"*" ~ recursiveMultilineCommentContent) + val semicolon: P[Unit] = P(";" ~ CharsWhileIn("; \t", min = 0) ~ position("line break after a semicolon").map(_ => ()) ~ (comment | "\r\n" | "\r" | "\n").opaque("")) val semicolonComment: P[Unit] = P(";" ~ CharsWhile(c => c != '\n' && c != '\r' && c != '{' && c != '}', min = 0) ~ position("line break instead of braces").map(_ => ()) ~ ("\r\n" | "\r" | "\n").opaque("")) - val AWS: P[Unit] = P((CharIn(" \t\n\r") | semicolon | comment).rep(min = 0)).opaque("") + val AWS: P[Unit] = P((CharIn(" \t\n\r") | semicolon | comment | multilineComment).rep(min = 0)).opaque("") val AWS_asm: P[Unit] = P((CharIn(" \t\n\r") | semicolonComment | comment).rep(min = 0)).opaque("") @@ -205,9 +214,9 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri case x => x.toString } | textLiteralAtom.! - val importStatement: P[Seq[ImportStatement]] = ("import" ~ !letterOrDigit ~/ SWS ~/ + val importStatement: P[Seq[ImportStatement]] = (position() ~ "import" ~ !letterOrDigit ~/ SWS ~/ identifier.rep(min = 1, sep = "/") ~ HWS ~ ("<" ~/ HWS ~/ quotedAtom.rep(min = 1, sep = HWS ~ "," ~/ HWS) ~/ HWS ~/ ">" ~/ Pass).?). - map{case (name, params) => Seq(ImportStatement(name.mkString("/"), params.getOrElse(Nil).toList))} + map{case (p, name, params) => Seq(ImportStatement(name.mkString("/"), params.getOrElse(Nil).toList).pos(p))} val optimizationHintsDeclaration: P[Set[String]] = if (options.flag(CompilationFlag.EnableInternalTestSyntax)) { @@ -235,6 +244,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } def variableDefinition(implicitlyGlobal: Boolean): P[Seq[BankedDeclarationStatement]] = for { + docComment <- (docComment ~ EOL).? p <- position() bank <- bankDeclaration flags <- variableFlags ~ HWS @@ -250,7 +260,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri constant = flags("const"), volatile = flags("volatile"), register = flags("register"), - initialValue, addr, optimizationHints, alignment).pos(p) + initialValue, addr, optimizationHints, alignment).pos(p).docComment(docComment) } } @@ -428,6 +438,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } val arrayDefinition: P[Seq[ArrayDeclarationStatement]] = for { + docComment <- (docComment ~ EOL).? p <- position() bank <- bankDeclaration const <- ("const".! ~ HWS).? @@ -443,7 +454,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } yield { if (alignment1.isDefined && alignment2.isDefined) log.error(s"Cannot define the alignment multiple times", Some(p)) val alignment = alignment1.orElse(alignment2) - Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, const.isDefined, contents, optimizationHints, alignment, options.isBigEndian).pos(p)) + Seq(ArrayDeclarationStatement(name, bank, length, elementType.getOrElse("byte"), addr, const.isDefined, contents, optimizationHints, alignment, options.isBigEndian).pos(p).docComment(docComment)) } def tightMfExpression(allowIntelHex: Boolean, allowTopLevelIndexing: Boolean): P[Expression] = { @@ -695,6 +706,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri } yield Seq(DoWhileStatement(body.toList, Nil, condition)) val functionDefinition: P[Seq[BankedDeclarationStatement]] = for { + docComment <- (docComment ~ AWS).? p <- position() bank <- bankDeclaration flags <- functionFlags ~ HWS @@ -754,7 +766,7 @@ abstract class MfParser[T](fileId: String, input: String, currentDirectory: Stri flags("interrupt"), flags("kernal_interrupt"), flags("const") && !flags("asm"), - flags("reentrant")).pos(p)) + flags("reentrant")).pos(p).docComment(docComment)) } def validateAsmFunctionBody(p: Position, flags: Set[String], name: String, statements: Option[List[Statement]]) diff --git a/src/main/scala/millfork/parser/MosParser.scala b/src/main/scala/millfork/parser/MosParser.scala index 4a1f1f18..81199bc7 100644 --- a/src/main/scala/millfork/parser/MosParser.scala +++ b/src/main/scala/millfork/parser/MosParser.scala @@ -115,19 +115,8 @@ case class MosParser(filename: String, input: String, currentDirectory: String, val asmStatement: P[ExecutableStatement] = (position("assembly statement") ~ P(asmLabel | asmMacro | arrayContentsForAsm | asmInstruction)).map { case (p, s) => s.pos(p) } // TODO: macros - - override val appcRegister: P[ParamPassingConvention] = P(("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "a" | "x" | "y") ~ !letterOrDigit).!.map { - case "xy" => ByMosRegister(MosRegister.XY) - case "yx" => ByMosRegister(MosRegister.YX) - case "ax" => ByMosRegister(MosRegister.AX) - case "ay" => ByMosRegister(MosRegister.AY) - case "xa" => ByMosRegister(MosRegister.XA) - case "ya" => ByMosRegister(MosRegister.YA) - case "a" => ByMosRegister(MosRegister.A) - case "x" => ByMosRegister(MosRegister.X) - case "y" => ByMosRegister(MosRegister.Y) - case x => log.fatal(s"Unknown assembly parameter passing convention: `$x`") - } + override val appcRegister: P[ParamPassingConvention] = P(("xy" | "yx" | "ax" | "ay" | "xa" | "ya" | "a" | "x" | "y") ~ !letterOrDigit).! + .map(name => ByMosRegister(MosRegister.fromString(name).getOrElse(log.fatal(s"Unknown assembly parameter passing convention: `$name`")))) def validateAsmFunctionBody(p: Position, flags: Set[String], name: String, statements: Option[List[Statement]]): Unit = { if (!options.flag(CompilationFlag.BuggyCodeWarning)) return diff --git a/src/main/scala/millfork/parser/Z80Parser.scala b/src/main/scala/millfork/parser/Z80Parser.scala index 2e8b90b4..5a7c6c81 100644 --- a/src/main/scala/millfork/parser/Z80Parser.scala +++ b/src/main/scala/millfork/parser/Z80Parser.scala @@ -31,19 +31,8 @@ case class Z80Parser(filename: String, private val zero = LiteralExpression(0, 1) - override val appcRegister: P[ParamPassingConvention] = (P("hl" | "bc" | "de" | "a" | "b" | "c" | "d" | "e" | "h" | "l").! ~ !letterOrDigit).map { - case "a" => ByZRegister(ZRegister.A) - case "b" => ByZRegister(ZRegister.B) - case "c" => ByZRegister(ZRegister.C) - case "d" => ByZRegister(ZRegister.D) - case "e" => ByZRegister(ZRegister.E) - case "h" => ByZRegister(ZRegister.H) - case "l" => ByZRegister(ZRegister.L) - case "hl" => ByZRegister(ZRegister.HL) - case "bc" => ByZRegister(ZRegister.BC) - case "de" => ByZRegister(ZRegister.DE) - case x => log.fatal(s"Unknown assembly parameter passing convention: `$x`") - } + override val appcRegister: P[ParamPassingConvention] = (P("hl" | "bc" | "de" | "a" | "b" | "c" | "d" | "e" | "h" | "l").! ~ !letterOrDigit) + .map(name => ByZRegister(ZRegister.fromString(name).getOrElse(log.fatal(s"Unknown assembly parameter passing convention: `$name`")))) override val asmParamDefinition: P[ParameterDeclaration] = for { p <- position() diff --git a/src/test/scala/millfork/test/language/MfLanguageServerSuite.scala b/src/test/scala/millfork/test/language/MfLanguageServerSuite.scala new file mode 100644 index 00000000..f39471a9 --- /dev/null +++ b/src/test/scala/millfork/test/language/MfLanguageServerSuite.scala @@ -0,0 +1,143 @@ +package millfork.test.language + +import org.scalatest.{AppendedClues, FunSpec, Matchers} +import millfork.test.language.util._ +import org.eclipse.lsp4j.DidOpenTextDocumentParams +import org.eclipse.lsp4j.TextDocumentItem +import org.eclipse.lsp4j.HoverParams +import org.eclipse.lsp4j.TextDocumentIdentifier +import org.eclipse.lsp4j.Position +import java.util.regex.Pattern +import scala.collection.mutable + +class MfLanguageServerSuite extends FunSpec with Matchers with AppendedClues { + describe("hover") { + it("should find node under cursor, and its root declaration") { + val server = LanguageHelper.createServer + + LanguageHelper.openDocument( + server, + "file.mfk", + """ + | byte test + | array(byte) foo[4] + | void main() { + | test = test + 1 + | foo[1] = test + | } + """ + ) + + { + // Select `test` variable usage + val hoverParams = new HoverParams( + new TextDocumentIdentifier("file.mfk"), + new Position(4, 3) + ) + val response = server.textDocumentHover(hoverParams) + val hover = response.get + + val contents = hover.getContents().getRight() + contents should not equal (null) + contents.getValue() should equal( + LanguageHelper.formatHover("byte test") + ) + } + { + // Select `main` function + val hoverParams = new HoverParams( + new TextDocumentIdentifier("file.mfk"), + new Position(3, 3) + ) + val response = server.textDocumentHover(hoverParams) + val hover = response.get + + val contents = hover.getContents().getRight() + contents should not equal (null) + contents.getValue() should equal( + LanguageHelper.formatHover("void main()") + ) + } + { + // Select `foo` array usage + val hoverParams = new HoverParams( + new TextDocumentIdentifier("file.mfk"), + new Position(5, 6) + ) + val response = server.textDocumentHover(hoverParams) + val hover = response.get + + val contents = hover.getContents().getRight() + contents should not equal (null) + contents.getValue() should equal( + LanguageHelper.formatHover("array(byte) foo [4]") + ) + } + } + + describe("should always produce value") { + val server = LanguageHelper.createServer + + val text = """ + | byte test + | array(byte) foo[4] + | void main() { + | test += test + | foo[1] = test + | func() + | } + | byte func() { + | byte i + | byte innerValue + | innerValue = 2 + | innerValue += innerValue + | return innerValue + | } + """.stripMargin + + LanguageHelper.openDocument( + server, + "file.mfk", + text + ) + + val lines = text.split("\n") + + val pattern = Pattern.compile("(return|byte)") + + for ((line, i) <- lines.zipWithIndex) { + val matcher = pattern.matcher(line) + + val ignoreRanges = mutable.MutableList[Range]() + + while (matcher.find()) { + ignoreRanges += Range(matcher.start(), matcher.end()) + } + + for ((character, column) <- line.toCharArray().zipWithIndex) { + if ( + Character.isLetter(character) && + // Ignore sections of string matching pattern + ignoreRanges.filter(r => r.contains(column)).length == 0 + ) { + it(s"""should work on ${i}, ${column} contents "${line}" """) { + val hoverParams = new HoverParams( + new TextDocumentIdentifier("file.mfk"), + new Position(i, column + 2) + ) + val response = server.textDocumentHover(hoverParams) + val hover = response.get + + hover should not equal (null) + + val contents = hover.getContents().getRight() + info(contents.toString()) + contents should not equal (null) + } + } + } + } + } + + } +} diff --git a/src/test/scala/millfork/test/language/NodeFinderSuite.scala b/src/test/scala/millfork/test/language/NodeFinderSuite.scala new file mode 100644 index 00000000..27ed2072 --- /dev/null +++ b/src/test/scala/millfork/test/language/NodeFinderSuite.scala @@ -0,0 +1,356 @@ +package millfork.test.language + +import org.scalatest.{AppendedClues, FunSpec, Matchers} +import millfork.test.language.util._ +import org.eclipse.lsp4j.DidOpenTextDocumentParams +import org.eclipse.lsp4j.TextDocumentItem +import org.eclipse.lsp4j.HoverParams +import org.eclipse.lsp4j.TextDocumentIdentifier +import millfork.language.NodeFinder +import millfork.node.Position +import millfork.node.FunctionDeclarationStatement +import millfork.node.ExpressionStatement +import millfork.node.FunctionCallExpression +import millfork.node.Assignment +import java.util.regex.Pattern +import millfork.node.Program +import millfork.node.IndexedExpression +import millfork.node.SumExpression + +class NodeFinderSuite extends FunSpec with Matchers with AppendedClues { + def createProgram(text: String): Program = { + val server = LanguageHelper.createServer + + LanguageHelper + .openDocument( + server, + "file.mfk", + text + ) + + server.cachedModules.get("file").get + } + + describe("nodeAtPosition") { + val text = """ + | + | byte test + | array(byte) foo[4] + | void main() { + | test += test + | foo[1] = test + | func() + | } + | byte func(byte arg) { + | byte i + | byte innerValue + | innerValue = 2 + | innerValue += innerValue + | innerValue += arg + | return innerValue + | } + """.stripMargin + + val program = createProgram(text) + + def findRangeOfString( + text: String, + textMatch: String, + afterLine: Int = 0 + ): (Int, Range) = { + val pattern = Pattern.compile(s"(${Pattern.quote(textMatch)})") + + val lines = text.split("\n") + for ((line, i) <- lines.zipWithIndex) { + if (i >= afterLine) { + val matcher = pattern.matcher(line) + + if (matcher.find()) { + return (i + 1, Range(matcher.start() + 2, matcher.end() + 2)) + } + } + } + + throw new Error(s"Cound not find pattern ${textMatch}") + } + + it("should find root variable declarations") { + val (line, range) = findRangeOfString(text, "test") + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._2 + .get(0) should equal( + program.declarations(0) + ) + } + } + + it("should find root array declarations") { + val (line, range) = findRangeOfString(text, "foo[4]") + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._2 + .get(0) should equal( + program.declarations(1) + ) + } + } + + it("should find function declarations") { + val (line, range) = findRangeOfString(text, "main()") + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._2 + .get(0) should equal( + program.declarations(2) + ) + } + } + + it("should find variable expression within function") { + val (line, range) = findRangeOfString(text, "test", 4) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(2) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(0) + .asInstanceOf[ExpressionStatement] + .expression + .asInstanceOf[FunctionCallExpression] + .expressions(0) + ) + } + } + + it("should find array expression within function") { + val (line, range) = findRangeOfString(text, "foo", 4) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(2) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(1) + .asInstanceOf[Assignment] + .destination + ) + } + } + + it("should find right hand side of assignment") { + val (line, range) = findRangeOfString(text, "test", 5) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(2) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(1) + .asInstanceOf[Assignment] + .source + ) + } + } + + it("should find function call") { + val (line, range) = findRangeOfString(text, "func()") + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(2) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(2) + .asInstanceOf[ExpressionStatement] + .expression + ) + } + } + + it("should find function argument") { + val (line, range) = findRangeOfString(text, "arg") + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(3) + .asInstanceOf[FunctionDeclarationStatement] + .params(0) + ) + } + } + + it("should find function nested variable declarations") { + val (line, range) = findRangeOfString(text, "i", 7) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(3) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(0) + ) + } + } + + it("should find variable used to index array") { + val innerText = """ + | byte root + | array(byte) anArray[10] + | void main() { + | byte index + | index = 4 + | root = anArray[index] + | index = anArray[root+1] + | } + """.stripMargin + + val program = createProgram(innerText) + + { + // Standard indexing + val (line, range) = findRangeOfString(innerText, "index", 6) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(2) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(2) + .asInstanceOf[Assignment] + .source + .asInstanceOf[IndexedExpression] + .index + ) + } + } + { + // Indexing within sum expression + val (line, range) = findRangeOfString(innerText, "root", 7) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + program + .declarations(2) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(3) + .asInstanceOf[Assignment] + .source + .asInstanceOf[IndexedExpression] + .index + .asInstanceOf[SumExpression] + .expressions(0) + ._2 + ) + } + } + } + + it("should find variables within a sum") { + val innerText = """ + | byte valA + | byte valB + | byte valC + | byte output + | void main() { + | output = valB + valA - 102 + valC + | } + """.stripMargin + + val program = createProgram(innerText) + + val sumExpression = program + .declarations(4) + .asInstanceOf[FunctionDeclarationStatement] + .statements + .get(0) + .asInstanceOf[Assignment] + .source + .asInstanceOf[SumExpression] + + { + val (line, range) = findRangeOfString(innerText, "valA", 5) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + sumExpression.expressions(1)._2 + ) + } + } + { + val (line, range) = findRangeOfString(innerText, "valB", 5) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + sumExpression.expressions(0)._2 + ) + } + } + { + val (line, range) = findRangeOfString(innerText, "valC", 5) + + for (column <- range) { + NodeFinder + .findNodeAtPosition(program, Position("", line, column, 0)) + ._1 + .get should equal( + sumExpression.expressions(3)._2 + ) + } + } + } + + // TODO: Additional tests: + // Fields on array indexing: spawn_info[index].hi + // Struct type: Player player + // Struct fields: player1.pos + // Messed up hover positions (in `nes_reset_joy.mfk`, each variable assignment to 0) + // Alias references + // Pointers: obj_ptr->xvel + } +} diff --git a/src/test/scala/millfork/test/language/util/LanguageHelper.scala b/src/test/scala/millfork/test/language/util/LanguageHelper.scala new file mode 100644 index 00000000..f981dc98 --- /dev/null +++ b/src/test/scala/millfork/test/language/util/LanguageHelper.scala @@ -0,0 +1,45 @@ +package millfork.test.language.util + +import millfork.error.Logger +import millfork.{NullLogger, Platform} +import millfork.language.MfLanguageServer +import millfork.Context +import millfork.CompilationOptions +import millfork.test.emu.{EmuPlatform, TestErrorReporting} +import millfork.Cpu +import millfork.JobContext +import millfork.compiler.LabelGenerator +import org.eclipse.lsp4j.DidOpenTextDocumentParams +import org.eclipse.lsp4j.TextDocumentItem +import org.eclipse.lsp4j.HoverParams +import org.eclipse.lsp4j.TextDocumentIdentifier +import org.eclipse.lsp4j.Position + +object LanguageHelper { + def createServer(): MfLanguageServer = { + implicit val logger: Logger = new NullLogger() + val platform = EmuPlatform.get(Cpu.Mos) + val jobContext = JobContext(TestErrorReporting.log, new LabelGenerator) + new MfLanguageServer( + new Context(logger, List()), + new CompilationOptions( + platform, + Map(), + None, + 0, + Map(), + EmuPlatform.textCodecRepository, + jobContext + ) + ) + } + + def openDocument(server: MfLanguageServer, name: String, text: String) = { + val textDocument = + new TextDocumentItem(name, "millfork", 1, text.stripMargin) + val openParams = new DidOpenTextDocumentParams(textDocument) + server.textDocumentDidOpen(openParams) + } + + def formatHover(text: String): String = s"""```mfk\n${text}\n```""" +}