19 changes: 19 additions & 0 deletions .github/workflows/publish.yml
@@ -72,6 +72,25 @@ jobs:
tags: ${{ steps.s3fs-meta.outputs.tags }}
labels: ${{ steps.s3fs-meta.outputs.labels }}

- name: Extract tensorrt metadata
id: tensorrt-meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}/tensorrt
tags: |
type=ref,event=tag
type=raw,value=latest,enable=${{ github.ref_type == 'tag' || github.ref == 'refs/heads/main' }}
type=sha,prefix=

- name: Build and push tensorrt
uses: docker/build-push-action@v5
with:
context: examples/tensorrt
platforms: linux/amd64
push: true
tags: ${{ steps.tensorrt-meta.outputs.tags }}
labels: ${{ steps.tensorrt-meta.outputs.labels }}

- name: Notify release result
if: always() && startsWith(github.ref, 'refs/tags/v')
uses: marimo-team/internal-gh-actions/release-notification@main
131 changes: 131 additions & 0 deletions examples/jax/jax_example.py
@@ -0,0 +1,131 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "flax>=0.10.0",
# "jax[cuda12]>=0.4.0",
# "marimo>=0.21.1",
# "matplotlib==3.10.8",
# "mofresh",
# "optax>=0.2.0",
# "polars>=1.0",
# ]
#
# [tool.marimo.runtime]
# auto_instantiate = false
#
# [tool.marimo.k8s]
# storage = "5Gi"
#
# [tool.marimo.k8s.resources]
# limits."nvidia.com/gpu" = 1
#
# [tool.marimo.k8s.nodeSelector]
# "gpu.nvidia.com/class" = "L40"
# ///

import marimo

__generated_with = "0.21.1"
app = marimo.App(width="columns", auto_download=["html"])

with app.setup(hide_code=True):
import marimo as mo
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import polars as pl
import numpy as np


@app.cell(hide_code=True)
def _():
mo.md(r"""
    marimo's multi-column layout keeps a live preview (here, the training-loss chart) in view while you edit the training code in another column.
""")
return


@app.cell
def _():
    import matplotlib.pyplot as plt
from mofresh import refresh_matplotlib, ImageRefreshWidget

widget = ImageRefreshWidget(src="")

@refresh_matplotlib
def losschart(data):
df = pl.DataFrame(data)
plt.plot(df["epoch"], df["loss_train"])

widget
return losschart, widget


@app.cell
def _():
return


@app.cell(column=1, hide_code=True)
def _():
mo.md(r"""
Example network and training with `jax`
""")
return


@app.cell
def _(losschart, widget):
datalogs = []

def train_identity_network(data, epochs=5000, learning_rate=0.025, iteration=0):
rng = jax.random.PRNGKey(42)
x = jnp.array(data, dtype=jnp.float32)

model = IdentityNetwork()
params = model.init(rng, x)

optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(params)

def loss_fn(p, batch):
return jnp.mean((model.apply(p, batch) - batch) ** 2)

for epoch in range(epochs):
loss, grads = jax.value_and_grad(loss_fn)(params, x)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

if (epoch + 1) % 25 == 0:
                # n and k come from the enclosing cell (set in the loop below),
                # so the held-out data matches the training shape.
                test_data = jnp.array(np.random.rand(n, k), dtype=jnp.float32)
test_loss = loss_fn(params, test_data)
datalogs.append({
"epoch": epoch + iteration * epochs,
"loss_train": float(loss),
"loss_test": float(test_loss),
})
widget.src = losschart(datalogs)

return params

    # One training round by default; widen range() for more rounds
    # (the iteration argument offsets the epoch axis in the chart).
    for i in range(1):
n = 10000
k = 10
random_data = np.random.rand(n, k)
trained_params = train_identity_network(random_data, epochs=5000, iteration=i)
return


@app.class_definition
class IdentityNetwork(nn.Module):
@nn.compact
def __call__(self, x):
features = x.shape[-1]
x = nn.relu(nn.Dense(features)(x))
x = nn.relu(nn.Dense(features)(x))
return x


if __name__ == "__main__":
app.run()
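
Aside: the training loop in this notebook runs each update eagerly. A common JAX refinement is to compile the update step once with `jax.jit`; the sketch below is illustrative only and not part of this PR, reusing the notebook's `IdentityNetwork`, `optax.sgd`, and MSE-to-input loss.

```python
# Illustrative sketch (not part of this PR): the notebook's update step
# compiled once with jax.jit, reusing IdentityNetwork from above.
import jax
import jax.numpy as jnp
import numpy as np
import optax

model = IdentityNetwork()  # class defined in the notebook above
optimizer = optax.sgd(0.025)

x = jnp.array(np.random.rand(10000, 10), dtype=jnp.float32)
params = model.init(jax.random.PRNGKey(42), x)
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state, batch):
    # Same MSE-to-input loss as the notebook, traced and compiled once.
    def loss_fn(p):
        return jnp.mean((model.apply(p, batch) - batch) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


for _ in range(5000):
    params, opt_state, loss = train_step(params, opt_state, x)
print(f"final loss: {float(loss):.6f}")
```
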
21 changes: 21 additions & 0 deletions examples/tensorrt/Dockerfile
@@ -0,0 +1,21 @@
FROM nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3

# The marimo operator requires:
# 1. uv — setup-venv init container runs: uv venv /opt/venv
# 2. marimo binary in PATH — main container command is: marimo run/edit ...
RUN pip install uv "marimo>=0.21.1"

# The operator launches the container without running the NVIDIA entrypoint, so
# env vars it would normally set must be baked in here instead.

# uv is pip-installed to /usr/local/bin; the operator defaults to /usr/bin, so override via $UV
ENV UV="/usr/local/bin/uv"

# TensorRT-LLM is in dist-packages (Ubuntu pip convention); operator venv lands at /opt/venv
ENV PYTHONPATH="/usr/local/lib/python3.12/dist-packages:/opt/venv/lib/python3.12/site-packages"

# TensorRT/cuBLAS .so files (entrypoint normally adds these to LD_LIBRARY_PATH)
ENV LD_LIBRARY_PATH="/usr/local/tensorrt/targets/x86_64-linux-gnu/lib:/usr/local/cuda/lib64:/usr/local/cuda/compat/lib:/usr/local/lib/python3.12/dist-packages/tensorrt_llm/libs:/usr/local/lib/python3.12/dist-packages/torch/lib:/usr/local/lib/python3.12/dist-packages/nvidia/cublas/lib:/usr/local/lib/python3.12/dist-packages/nvidia/cuda_runtime/lib"

# Reduce TRT-LLM log noise by default; notebooks can override if needed.
ENV TRTLLM_LOG_LEVEL="WARNING"
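
To sanity-check the baked-in environment before deploying, a small probe can be run inside the container. This is a sketch, not part of the image: the `docker run` invocation in the comment is one way to execute it, and it assumes only that `tensorrt_llm` exposes the usual `__version__` attribute.

```python
# Sketch: run inside the container, e.g.
#   docker run --rm <image> python3 probe.py
# to confirm the ENV lines above are picked up. Illustrative only.
import os

import tensorrt_llm  # resolvable thanks to the PYTHONPATH set above

print("tensorrt_llm:", tensorrt_llm.__version__)
print("UV:", os.environ.get("UV"))
print("LD_LIBRARY_PATH includes tensorrt:",
      "tensorrt" in os.environ.get("LD_LIBRARY_PATH", ""))
```
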
62 changes: 62 additions & 0 deletions examples/tensorrt/README.md
@@ -0,0 +1,62 @@
# TensorRT-LLM Example

Interactive [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) inference notebook running on the marimo operator. Includes a model picker with several open-weight models and a reactive prompt selector.

Requires a GPU node with ≥ 24 GB VRAM. The included models range from ~3 GB (TinyLlama 1.1B) up to ~16 GB (Mistral 7B / Minitron 8B).

## Files

| File | Purpose |
|---|---|
| `tensorrt.py` | marimo notebook — deploy with the `kubectl-marimo` plugin |
| `tensorrt.yaml` | Plain `MarimoNotebook` manifest — deploy with `kubectl apply` |
| `Dockerfile` | Thin layer over the NVIDIA NGC TRT-LLM container |

## Deploy with the plugin

```bash
uv tool install kubectl-marimo

kubectl marimo edit tensorrt.py -n <namespace>
```

## Deploy with kubectl

```bash
# Edit nodeSelector if needed, then apply
kubectl apply -f tensorrt.yaml -n <namespace>

# Get the access token
kubectl logs -n <namespace> tensorrt -c marimo | grep access_token

# Port-forward to access locally
kubectl port-forward -n <namespace> svc/tensorrt 2718:2718
```

## Build the image

The pre-built image is at `ghcr.io/marimo-team/marimo-operator/tensorrt:latest` and is rebuilt on every push to main. To build your own:

```bash
docker build --platform linux/amd64 -t <your-registry>/tensorrt:latest .
docker push <your-registry>/tensorrt:latest
```

If your registry is private, create an image pull secret:

```bash
kubectl create secret docker-registry registry-secret \
--docker-server=<registry> \
--docker-username=<user> \
--docker-password=<token> \
-n <namespace>

kubectl patch serviceaccount default -n <namespace> \
-p '{"imagePullSecrets":[{"name":"registry-secret"}]}'
```

## Further reading

- [TensorRT-LLM documentation](https://nvidia.github.io/TensorRT-LLM/)
- [FP8 quantization guide](https://nvidia.github.io/TensorRT-LLM/performance/performance-tuning-guide/fp8-quantization.html)
- [marimo operator](https://github.com/marimo-team/marimo-operator)
107 changes: 107 additions & 0 deletions examples/tensorrt/tensorrt.py
@@ -0,0 +1,107 @@
# /// script
# dependencies = ["marimo>=0.21.1", "setuptools"]
#
# [tool.marimo.k8s]
# image = "ghcr.io/marimo-team/marimo-operator/tensorrt:latest"
# storage = "20Gi"
#
# [tool.marimo.k8s.resources]
# limits."nvidia.com/gpu" = 1
#
# [tool.marimo.k8s.nodeSelector]
# "gpu.nvidia.com/class" = "L40"
#
# ///

import marimo

__generated_with = "0.21.1"
app = marimo.App(width="medium")

with app.setup:
import gc
import io
import logging

import marimo as mo
import torch
from tensorrt_llm import LLM, SamplingParams


@app.cell
def _():
# Collect TRT-LLM / torch log lines emitted during model load so we can
# surface them in the notebook rather than losing them to pod stdout.
log_stream = io.StringIO()
    # force=True replaces any handlers the runtime already installed,
    # so WARNING+ lines actually land in log_stream.
    logging.basicConfig(stream=log_stream, level=logging.WARNING, force=True)

models = {
"TinyLlama 1.1B (fast)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Phi-3.5-mini 3.8B": "microsoft/Phi-3.5-mini-instruct",
"Mistral 7B": "mistralai/Mistral-7B-Instruct-v0.3",
"Llama-3.1 8B FP8 (NVIDIA)": "nvidia/Llama-3.1-8B-Instruct-FP8",
"Minitron 8B (NVIDIA)": "nvidia/Mistral-NeMo-Minitron-8B-Instruct",
}

prompts = [
"Hello, my name is",
"The capital of France is",
"The future of AI is",
]

prompt = mo.ui.dropdown(
label="Select a Prompt",
options=prompts,
value=prompts[0],
)

model_picker = mo.ui.dropdown(
label="Model",
options=models,
value="TinyLlama 1.1B (fast)",
)

mo.vstack([
mo.md("## TensorRT-LLM Inference Example, select a model to download and run"),
model_picker,
])
return log_stream, model_picker, models, prompt, prompts


@app.cell
def _(log_stream, model_picker):
mo.stop(model_picker.value is None)

# Free the previous model's GPU memory before loading the next one.
# Without this, two models can coexist briefly and OOM a 48 GB L40.
gc.collect()
torch.cuda.empty_cache()

# Model can be a HuggingFace model name, a local path, or a quantized
# checkpoint such as nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
llm = LLM(model=model_picker.value)

# Show any WARNING+ log lines collected during model load.
logs = log_stream.getvalue()
mo.accordion({"Model load logs": mo.plain_text(logs) if logs else mo.md("_No warnings._")})
return (llm,)


@app.cell(hide_code=True)
def _(llm, prompt):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

results = []
for output in llm.generate([prompt.value], sampling_params):
results.append(
mo.md(
f"**Prompt:** {output.prompt} \n"
f"**Generated:** {output.outputs[0].text}"
)
)

mo.vstack([prompt, mo.vstack(results)])


if __name__ == "__main__":
app.run()
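
For reference, the notebook's inference path also works as a plain script. The sketch below uses only `LLM`, `SamplingParams`, and `llm.generate` as shown above; `max_tokens` is an assumed `SamplingParams` knob beyond what the notebook exercises, and the model name is one of the entries from the picker.

```python
# Sketch: the notebook's inference path as a standalone script.
# max_tokens is an assumed SamplingParams option; adjust as needed.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

for output in llm.generate(["The capital of France is"], sampling_params):
    print(output.prompt, "->", output.outputs[0].text)
```
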