Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions backend/commonir/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,73 @@ def from_database(
result_idx,
target,
func_or_mod,
kernel_global_source,
host_kernel_source,
kernel_lib_path,
pass_configs,
):
return cls.compile_and_create_adapter(func_or_mod)

@classmethod
def _tilelang_to_commonir(cls, tilelang_module):
from tilelang.engine import lower
from tilelang.engine.lower import extrac_params
from tilelang.engine.param import CompiledArtifact
from tilelang.engine.phase import (
PreLowerSemanticCheck,
LowerAndLegalize,
)
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir, IRModule
from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll

def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.RemoveNoOp()(mod)
return mod

def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
return tir.transform.BindTarget(target)(mod)

def canon_target_host(target: str | Target, target_host: str | Target | None):
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"

return target_host

def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
return tvm.ffi.get_global_func("target.build.tilelang_commonir")(
device_mod, target
)

def lower(
func_or_mod: tir.PrimFunc | tvm.IRModule,
target: str | Target = "auto",
target_host: str | Target | None = None,
runtime_only=False,
) -> CompiledArtifact:
mod = func_or_mod
params = None
if isinstance(func_or_mod, tir.PrimFunc):
func = func_or_mod
params = extrac_params(func) if not runtime_only else None
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
target = "commonir"
target_host = canon_target_host(target, target_host)
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)
# Before lowering, do semantic check
PreLowerSemanticCheck(mod)
# Phase 1: Lower and legalize the IR
mod = LowerAndLegalize(mod, target)
# Phase 2: Optimize the IR for the target
mod = OptimizeForTarget(mod, target)
codegen_mod = device_codegen(mod, target)
# print(codegen_mod.inspect_source())
return CompiledArtifact(
None, codegen_mod, params, codegen_mod.inspect_source(), None
)

debug_enabled = os.environ.get("TILELANG_PRINT_COMMONIR", "0") in (
"1",
"true",
Expand All @@ -98,11 +153,8 @@ def _tilelang_to_commonir(cls, tilelang_module):

instruments = [PrintAfterAll(), PrintBeforeAll()] if debug_enabled else []
with tvm.transform.PassContext(instruments=instruments):
mlir_path = lower(tilelang_module)
if mlir_path.endswith(".mlir"):
mlir_content = cls._read_mlir_file(mlir_path)
else:
mlir_content = mlir_path
lower_result = lower(tilelang_module)
mlir_content = lower_result.kernel_source
return mlir_content

@classmethod
Expand Down
31 changes: 31 additions & 0 deletions cmake/commonir.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

set(GENERATED_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/target)
file(MAKE_DIRECTORY ${GENERATED_SRC_DIR})

set(COMMONIR_SOURCE_FILES
codegen_commonir.cc
codegen_commonir.h
rt_mod_commonir.cc
)

set(GENERATED_SRCS "")
foreach(file_name IN LISTS COMMONIR_SOURCE_FILES)
set(src_path ${CMAKE_CURRENT_LIST_DIR}/../commonir/src/target/${file_name})
set(dst_path ${GENERATED_SRC_DIR}/${file_name})

add_custom_command(
OUTPUT ${dst_path}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${src_path} ${dst_path}
DEPENDS ${src_path}
COMMENT "Generating ${file_name} from CommonIR"
VERBATIM
)
list(APPEND GENERATED_SRCS ${dst_path})
endforeach()

set(TILE_LANG_COMMONIR_SRCS
${GENERATED_SRCS}
# ${CMAKE_CURRENT_SOURCE_DIR}/src/target/codegen_commonir.cc
# ${CMAKE_CURRENT_SOURCE_DIR}/src/target/rt_mod_commonir.cc
)
list(APPEND TILE_LANG_SRCS ${TILE_LANG_COMMONIR_SRCS})
Loading