Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion comfy_api_nodes/apis/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
class VideoGenerationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
image: InputUrlObject | None = Field(...)
image: InputUrlObject | None = Field(None)
reference_images: list[InputUrlObject] | None = Field(None)
duration: int = Field(...)
aspect_ratio: str | None = Field(...)
resolution: str = Field(...)
seed: int = Field(...)


class VideoExtensionRequest(BaseModel):
prompt: str = Field(...)
video: InputUrlObject = Field(...)
duration: int = Field(default=6)
model: str | None = Field(default=None)


class VideoEditRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
Expand Down
251 changes: 251 additions & 0 deletions comfy_api_nodes/nodes_grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ImageGenerationResponse,
InputUrlObject,
VideoEditRequest,
VideoExtensionRequest,
VideoGenerationRequest,
VideoGenerationResponse,
VideoStatusResponse,
Expand All @@ -21,6 +22,7 @@
poll_op,
sync_op,
tensor_to_base64_string,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_duration,
Expand All @@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
return None


def _extract_grok_video_price(response) -> float | None:
price = _extract_grok_price(response)
if price is not None:
return price * 1.43
return None


class GrokImageNode(IO.ComfyNode):

@classmethod
Expand Down Expand Up @@ -354,6 +363,8 @@ async def execute(
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
if model == "grok-imagine-video-beta":
model = "grok-imagine-video"
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
Expand Down Expand Up @@ -462,14 +473,254 @@ async def execute(
return IO.NodeOutput(await download_url_to_video_output(response.video.url))


class GrokVideoReferenceNode(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="api node/video/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of the desired video.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="reference_",
min=1,
max=7,
),
tooltip="Up to 7 reference images to guide the video generation.",
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
tooltip="The aspect ratio of the output video.",
),
IO.Int.Input(
"duration",
default=6,
min=2,
max=10,
step=1,
tooltip="The duration of the output video in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model.duration", "model.resolution"],
input_groups=["model.reference_images"],
),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$refs := inputGroups["model.reference_images"];
$rate := $res = "720p" ? 0.07 : 0.05;
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
{"type":"usd","usd": $price}
)
""",
),
)

@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
ref_image_urls = await upload_images_to_comfyapi(
cls,
list(model["reference_images"].values()),
mime_type="image/png",
wait_label="Uploading base images",
max_images=7,
)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model["model"],
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
prompt=prompt,
resolution=model["resolution"],
duration=model["duration"],
aspect_ratio=model["aspect_ratio"],
seed=seed,
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))


class GrokVideoExtendNode(IO.ComfyNode):

@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="api node/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text description of what should happen next in the video.",
),
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-video",
[
IO.Int.Input(
"duration",
default=8,
min=2,
max=10,
step=1,
tooltip="Length of the extension in seconds.",
display_mode=IO.NumberDisplay.slider,
),
],
),
],
tooltip="The model to use for video extension.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
expr="""
(
$dur := $lookup(widgets, "model.duration");
{
"type": "range_usd",
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
"max_usd": (0.15 + 0.05 * $dur) * 1.43
}
)
""",
),
)

@classmethod
async def execute(
cls,
prompt: str,
video: Input.Video,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
validate_video_duration(video, min_duration=2, max_duration=15)
video_size = get_fs_object_size(video.get_stream_source())
if video_size > 50 * 1024 * 1024:
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
data=VideoExtensionRequest(
prompt=prompt,
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
duration=model["duration"],
),
response_model=VideoGenerationResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_video_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))


class GrokExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GrokImageNode,
GrokImageEditNode,
GrokVideoNode,
GrokVideoReferenceNode,
GrokVideoEditNode,
GrokVideoExtendNode,
]


Expand Down
Loading