
Add trainium, inferentia, and efa parameters to @kubernetes decorator#3086

Open
emattia wants to merge 7 commits into
Netflix:masterfrom
emattia:trn-k8s

Contributor

@emattia emattia commented Apr 8, 2026

PR Type

  • Bug fix
  • New feature
  • Core Runtime change
  • Docs / tooling
  • Refactoring

Summary

Mirror @batch's AWS-accelerator surface on @kubernetes:

  • @kubernetes(trainium=N) requests N AWS Trainium / Inferentia Neuron
    devices (aws.amazon.com/neuron k8s resource).
  • @kubernetes(inferentia=N) is an alias for trainium, mirroring
    @batch(inferentia=N) for API consistency.
  • @kubernetes(efa=N) requests N AWS Elastic Fabric Adapter network
    interfaces (vpc.amazonaws.com/efa k8s resource).

Plumbed through kubernetes_job, kubernetes_jobsets, kubernetes_cli,
and the argo / airflow runtimes consistently with how the existing
gpu parameter is handled.
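
As a rough illustration of the mapping described above, the sketch below turns decorator-level values into a Kubernetes resource-limits dictionary. The helper name `accelerator_limits` is hypothetical and is not Metaflow's actual code; only the two resource names come from this PR description.

```python
from typing import Dict, Optional

# Resource names taken from the PR description.
NEURON_RESOURCE = "aws.amazon.com/neuron"
EFA_RESOURCE = "vpc.amazonaws.com/efa"


def accelerator_limits(
    trainium: Optional[int] = None,
    efa: Optional[int] = None,
) -> Dict[str, str]:
    """Hypothetical helper: build the extra resource limits for a pod.

    None means "not requested", so no key is emitted for that resource.
    """
    limits: Dict[str, str] = {}
    if trainium is not None:
        limits[NEURON_RESOURCE] = str(trainium)
    if efa is not None:
        limits[EFA_RESOURCE] = str(efa)
    return limits


if __name__ == "__main__":
    print(accelerator_limits(trainium=1, efa=32))
```

Gating on `None` rather than truthiness keeps flows that never mention these parameters byte-for-byte unchanged in their pod specs.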

Issue

No tracking issue. Supersedes the original PR scope of just trainium.
Brings the @kubernetes path to parity with @batch for AWS Neuron
and EFA workloads, unblocking customers who run their own EKS clusters
and want first-class Neuron/EFA support without writing raw pod specs.

Reproduction

Runtime: kubernetes (EKS with AWS Neuron and EFA device plugins installed; nodes labeled with the relevant accelerator).

Commands to run:

from metaflow import FlowSpec, step, kubernetes, environment

NEURON_IMG = "public.ecr.aws/neuron/pytorch-training-neuronx:2.9.0-neuronx-py312-sdk2.29.1-ubuntu24.04"

class NeuronEfaSmoke(FlowSpec):

    @step
    def start(self):
        self.next(self.neuron_only)

    @kubernetes(trainium=1, image=NEURON_IMG)
    @step
    def neuron_only(self):
        import subprocess
        # List the Neuron devices visible inside the container
        print(subprocess.check_output(["neuron-ls"]).decode())
        self.next(self.inferentia_alias)

    # Equivalent - inferentia is an alias for trainium
    @kubernetes(inferentia=1, image=NEURON_IMG)
    @step
    def inferentia_alias(self):
        self.next(self.gpu_with_efa)

    @environment(vars={"FI_PROVIDER": "efa"})
    @kubernetes(gpu=8, efa=32, image="<aws-dlc-pytorch-cuda>")
    @step
    def gpu_with_efa(self):
        import torch.distributed as dist
        dist.init_process_group(backend="nccl")
        # NCCL debug log will show "Selected provider is efa"
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    NeuronEfaSmoke()

Where evidence shows up: task pod spec (kubectl describe pod) and
NCCL debug log inside the running container.

Before (master)
TypeError: kubernetes() got an unexpected keyword argument 'trainium'

(also for inferentia, efa)

After (this PR)
$ kubectl describe pod ws-...
...
Limits:
  aws.amazon.com/neuron:    1
  vpc.amazonaws.com/efa:    32

Root Cause

Not a bug fix — net-new feature. The underlying Kubernetes resources
(aws.amazon.com/neuron, vpc.amazonaws.com/efa) are advertised by the
respective AWS device plugins; @kubernetes had no decorator-level
surface to request them. @batch already exposed trainium,
inferentia, and efa. This PR brings @kubernetes to parity.

Why This Fix Is Correct

  • Mirrors @batch's API surface exactly. inferentia collapses into
    trainium at step_init and is popped before any runtime translation
    — same shape as batch_decorator.py:175-211, only with trainium as
    canonical (since on K8s the underlying resource name is
    aws.amazon.com/neuron and we surface what users running on Trainium
    hardware naturally type first).
  • Doesn't disturb the existing GPU path. gpu and trainium are
    enforced as mutually exclusive (matching @batch's convention).
  • Argo/Airflow runtimes already had the trainium plumbing pattern from
    earlier in this branch; efa follows the same pattern.
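
The alias collapse described in the first bullet can be sketched as follows. This is a minimal illustration and not the code in `kubernetes_decorator.py`: the function name `collapse_inferentia_alias` is hypothetical, and `ValueError` stands in for whatever exception type the decorator actually raises.

```python
from typing import Any, Dict


def collapse_inferentia_alias(attributes: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: fold `inferentia` into the canonical `trainium` key."""
    attrs = dict(attributes)  # work on a copy; the caller's dict is untouched
    # Pop the alias so later stages only ever see `trainium`.
    inferentia = attrs.pop("inferentia", None)
    if inferentia is not None:
        if attrs.get("trainium") is not None:
            # Assumption: the real decorator raises its own exception type.
            raise ValueError("Specify only one of 'trainium' or 'inferentia'.")
        attrs["trainium"] = inferentia
    return attrs


if __name__ == "__main__":
    print(collapse_inferentia_alias({"inferentia": 2, "trainium": None}))
```

Because the alias key is removed before returning, any downstream CLI or workflow translation only needs to know about one name.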

Failure Modes Considered

  1. Backward compat: flows using only gpu / gpu_vendor are
    unaffected — new attributes default to None and resource-limit
    emission is gated on non-None values.
  2. Mutual exclusion: specifying both inferentia and trainium
    raises a clear error in step_init (mirrors @batch). Specifying
    both gpu and trainium was already enforced.
  3. Wire format consistency: inferentia is popped from
    self.attributes after collapsing into trainium, so the runtime
    CLI / argo / airflow translation only ever sees the canonical key.
  4. Cross-runtime: changes propagate through kubernetes_job,
    kubernetes_jobsets, argo, and airflow consistently with how
    trainium was already plumbed.
  5. Validation: efa value validated as positive integer (mirrors
    trainium and tmpfs_size validation patterns in the same file).
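
For item 5, a positive-integer check might look like the sketch below. It is an assumption about the shape of the validation, not a copy of the decorator's code: the name `validate_positive_int` is hypothetical, `ValueError` is a stand-in exception, and accepting digit strings is a guess at how values may arrive from the CLI.

```python
from typing import Any, Optional


def validate_positive_int(name: str, value: Any) -> Optional[int]:
    """Hypothetical validator: return None if unset, else a positive int."""
    if value is None:
        return None  # unset: nothing to validate, nothing to emit
    message = "Invalid %s value %r: expected a positive integer." % (name, value)
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(value, bool):
        raise ValueError(message)
    if isinstance(value, int):
        number = value
    elif isinstance(value, str) and value.strip().isdecimal():
        number = int(value.strip())
    else:
        raise ValueError(message)
    if number <= 0:
        raise ValueError(message)
    return number


if __name__ == "__main__":
    print(validate_positive_int("efa", "32"))
```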

Tests

  • Unit tests added/updated
  • Manual reproduction provided above
  • Smoke-tested end-to-end on a real EKS cluster with Neuron and
    EFA device plugins. Pod spec contains the right resource limits;
    NCCL via aws-ofi-nccl selects EFA as the network backend.
  • CI passes — TBD (CI doesn't have AWS Trainium/EFA hardware).

Non-Goals

  • Not touching @batch (already has these parameters).
  • Not adding a --inferentia CLI flag — inferentia is purely a
    decorator-time convenience that resolves to trainium before any
    CLI invocation, mirroring @batch's CLI which only exposes the
    canonical name (--inferentia for batch since inferentia is
    canonical there; --trainium for k8s since trainium is canonical
    here).
  • Not adding NCCL/libfabric environment-variable defaults
    (FI_PROVIDER, FI_EFA_USE_DEVICE_RDMA). Users set those via
    @environment for now; auto-injection is a separate ergonomics PR.
  • Not opining on which Trainium/Inferentia instance type a user
    should target; that's a cluster-side concern (instance allowlist
    + AMI selection on the EKS managed nodegroup side).

AI Tool Usage

  • AI tools were used (Anthropic Claude - research on AWS DLC tag
    selection, Karpenter EFA NIC layout prior art, and drafting this
    PR description). All generated code reviewed, understood, and
    tested end-to-end on a live Outerbounds cluster.

@emattia emattia marked this pull request as draft April 8, 2026 00:32
Contributor

greptile-apps Bot commented Apr 8, 2026

Greptile Summary

This PR adds trainium, inferentia, and efa parameters to the @kubernetes decorator, bringing it to parity with @batch for AWS Neuron and EFA workloads. The implementation is consistent and well-structured across all affected runtimes.

  • trainium/inferentia: inferentia is a decorator-time alias that collapses to trainium before any CLI or runtime translation. Both map to the aws.amazon.com/neuron extended resource limit, with an auto-injected aws.amazon.com/neuron:NoSchedule toleration propagated through kubernetes_job, kubernetes_jobsets, Argo, and Airflow paths.
  • efa: Maps to vpc.amazonaws.com/efa resource limits. No EFA-specific toleration is injected (EFA nodes in EKS do not carry a standard device-plugin taint), which is correct.
  • Validation guards (positive integer, gpu/trainium mutual exclusion) are applied at step_init time, mirroring the tmpfs_size and gpu patterns already in the decorator.
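
The toleration injection mentioned in the first bullet could be sketched roughly as below. The helper name `with_neuron_toleration` is hypothetical, and the `"operator": "Exists"` field is an assumption; the summary only states the key and the `NoSchedule` effect.

```python
from typing import Dict, List, Optional

# Key and effect come from the review summary; the operator is an assumption.
NEURON_TOLERATION: Dict[str, str] = {
    "key": "aws.amazon.com/neuron",
    "operator": "Exists",
    "effect": "NoSchedule",
}


def with_neuron_toleration(
    tolerations: Optional[List[Dict[str, str]]],
    trainium: Optional[int],
) -> List[Dict[str, str]]:
    """Hypothetical helper: append the Neuron toleration when trainium is set."""
    result = list(tolerations or [])  # never mutate the caller's list
    if trainium is not None and NEURON_TOLERATION not in result:
        result.append(dict(NEURON_TOLERATION))
    return result


if __name__ == "__main__":
    print(with_neuron_toleration(None, trainium=1))
```

The membership check avoids a duplicate entry when a user has already supplied the same toleration by hand.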

Confidence Score: 5/5

Safe to merge. The changes are additive, default to None, and follow the exact same patterns as the existing GPU implementation across all runtimes.

All new parameters default to None and are gated before any resource or toleration is emitted. Validation (positive integer, gpu/trainium mutual exclusion) runs at step_init time. The inferentia-to-trainium aliasing is correctly handled before CLI serialization, and None values are safely skipped by the CLI arg builder. The implementation is consistent across kubernetes_job, kubernetes_jobsets, Argo, and Airflow paths.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| metaflow/plugins/kubernetes/kubernetes_decorator.py | Adds trainium/inferentia/efa attributes with proper aliasing, mutual-exclusion validation, and positive-integer guards; consistent with existing gpu handling patterns. |
| metaflow/plugins/kubernetes/kubernetes_job.py | Adds neuron and EFA resource limits and auto-injects the aws.amazon.com/neuron:NoSchedule toleration when trainium is set; mirrors the existing GPU limit pattern. |
| metaflow/plugins/kubernetes/kubernetes_jobsets.py | Same neuron/EFA resource-limit and toleration additions as kubernetes_job.py, consistently applied to the JobSet path. |
| metaflow/plugins/argo/argo_workflows.py | Threads trainium and efa through the Argo template and jobset paths; adds neuron toleration consistently; resource limits follow the same pattern as GPU. |
| metaflow/plugins/airflow/airflow.py | Adds neuron and EFA resource limits to the Airflow KubernetesPodOperator resources dict and injects the neuron toleration when trainium is set. |
| metaflow/plugins/kubernetes/kubernetes_cli.py | Adds --trainium and --efa Click options and threads them through the step function; no --inferentia CLI flag by design (decorator-time alias only). |
| metaflow/plugins/kubernetes/kubernetes.py | Adds trainium and efa keyword arguments to both execute() overloads and forwards them to the job/jobset constructors. |

Reviews (3): Last reviewed commit: "Make @kubernetes inferentia-trainium ali..."

Comment thread metaflow/plugins/airflow/airflow.py
Comment thread metaflow/plugins/kubernetes/kubernetes_decorator.py

@emattia emattia changed the title Add trainium parameter to @kubernetes decorator Add trainium, inferentia, and efa parameters to @kubernetes decorator May 4, 2026
@emattia emattia force-pushed the trn-k8s branch 2 times, most recently from 9b816f1 to b1db907 Compare May 10, 2026 21:13
@emattia emattia marked this pull request as ready for review May 10, 2026 21:46
emattia added 7 commits May 14, 2026 07:02
Collaborator

@saikonen saikonen left a comment


No issues with the changes themselves, but a question on the overall UX. Is the main goal of this feature added convenience? From what I can tell, everything seems to already be achievable with @kubernetes(tolerations=) if I'm not mistaken.

It also seems a bit of a departure for the @kubernetes decorator to implement provider-specific attributes.

Contributor Author

emattia commented May 14, 2026

> No issues with the changes themselves, but a question on the overall UX. Is the main goal of this feature added convenience? From what I can tell, everything seems to already be achievable with @kubernetes(tolerations=) if I'm not mistaken.
>
> It also seems a bit of a departure for the @kubernetes decorator to implement provider-specific attributes.

Right, the tolerations are what needs to happen; the @kubernetes parameters are there to make this easy for users who don't care about the backend machine. I agree it is a departure. My logic was to make the minimal change that brings parity with the existing Metaflow API on the @batch side. Ideally, users shouldn't have to set tolerations in @kubernetes, in my opinion, as many of our end users don't know what that word means. Do you have thoughts on a better approach?

# Validate mutually exclusive: gpu and trainium cannot both be set.
if (
    self.attributes["trainium"] is not None
    and self.attributes["gpu"] is not None
Collaborator


Should this check for values greater than zero instead? trainium together with gpu=0 should work, right?
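
One possible reading of this suggestion, sketched under assumptions: treat a resource as requested only when it is set and greater than zero, so that `gpu=0` does not conflict with `trainium`. The helper names are hypothetical, `ValueError` stands in for the decorator's real exception type, and this is not the code in the PR.

```python
from typing import Any, Dict


def is_requested(value: Any) -> bool:
    """True when a numeric resource request is set and greater than zero."""
    return value is not None and int(value) > 0


def check_gpu_trainium_exclusive(attributes: Dict[str, Any]) -> None:
    """Hypothetical check: reject only real (> 0) requests for both resources."""
    if is_requested(attributes.get("gpu")) and is_requested(attributes.get("trainium")):
        raise ValueError("Specify only one of 'gpu' or 'trainium'.")


if __name__ == "__main__":
    # gpu=0 is treated as "no GPU requested", so this passes silently.
    check_gpu_trainium_exclusive({"gpu": 0, "trainium": 1})
```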
