Skip to content

Attention只能使用flash_attn方式计算吗? #12

@goldlee

Description

@goldlee

def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    backend: str = "flash_attn",
    *,
    causal: bool = False,
    softmax_scale: float = None,
    attn_kwargs: dict = None,
):
    """Compute attention with flash-attn varlen kernels or torch SDPA fallback.

    Args:
        q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim].
        k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim].
        v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim].
        backend (str): Requested backend; only "flash_attn" is accepted when the
            preferred backend is flash-attn.
        causal (bool): Apply a causal mask when True.
        softmax_scale (float, optional): Softmax scaling factor; when None the
            backend uses its default of 1/sqrt(head_dim).
        attn_kwargs (dict, optional): Required on the flash-attn path; must hold
            'cu_seqlens_q', 'cu_seqlens_kv', 'max_seqlen_q', 'max_seqlen_kv'.

    Returns:
        torch.Tensor: Output of shape [batch_size, seq_len, num_heads, head_dim].

    Raises:
        ValueError: If the flash-attn path is taken without `attn_kwargs`.
    """
    if "flash_attn" == get_preferred_attention_backend():
        assert backend in ["flash_attn"], f"Unsupported attention backend: {backend}"
        assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Input tensors must be 4D"
        if attn_kwargs is None:
            raise ValueError(
                "attn_kwargs with cu_seqlens_q/cu_seqlens_kv/max_seqlen_q/max_seqlen_kv "
                "is required for the flash_attn backend"
            )
        batch_size = q.shape[0]

        cu_seqlens_q = attn_kwargs['cu_seqlens_q']
        cu_seqlens_kv = attn_kwargs['cu_seqlens_kv']
        max_seqlen_q = attn_kwargs['max_seqlen_q']
        max_seqlen_kv = attn_kwargs['max_seqlen_kv']
        # Flatten [batch, seq, heads, dim] -> [batch*seq, heads, dim] as the
        # varlen API expects packed sequences delimited by cu_seqlens.
        x = flash_attn_varlen_func(
            q.view(q.shape[0] * q.shape[1], *q.shape[2:]),
            k.view(k.shape[0] * k.shape[1], *k.shape[2:]),
            v.view(v.shape[0] * v.shape[1], *v.shape[2:]),
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            # Forward the masking/scaling options so both backends agree;
            # previously these were silently ignored on this path.
            softmax_scale=softmax_scale,
            causal=causal,
        )
        output = x.view(
            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
        )
    else:
        from torch.nn.functional import scaled_dot_product_attention

        # NOTE(review): this fallback ignores attn_kwargs (cu_seqlens), so if
        # the batch contains padded positions they are attended to — a likely
        # source of outputs diverging from the flash-attn varlen path. Confirm
        # the inputs are unpadded before relying on this branch.
        # Reshape: [batch, seq_len, heads, dim] -> [batch, heads, seq_len, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        output = scaled_dot_product_attention(
            q, k, v,
            is_causal=causal,
            scale=softmax_scale
        )

        # Back to [batch, seq_len, heads, dim]
        output = output.transpose(1, 2)

    return output

我添加了torch原生的Attention计算,但是使用inference.py进行推理时,结果完全不对
我使用了https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers里面的input.png,和output1_predicted.png完全对不上,是哪里有问题吗?参数如下:
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments for local (non-FastAPI) inference.

    Args:
        argv (list[str], optional): Argument list to parse; defaults to
            ``sys.argv[1:]`` when None, preserving the original behavior.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Run local inference without FastAPI.')
    parser.add_argument('--ckpt-root', default="JoyAI-Image-Edit", help='Checkpoint root.')
    parser.add_argument('--prompt', default="Remove the construction structure from the top of the crane.", help='Edit prompt or T2I prompt.')
    parser.add_argument('--image', default="test_images/input.png", help='Optional input image path for image editing.')
    parser.add_argument('--output', default='example.png', help='Output image path.')
    parser.add_argument('--height', type=int, default=1024, help='Only used for text-to-image inference.')
    parser.add_argument('--width', type=int, default=1024, help='Only used for text-to-image inference.')
    parser.add_argument('--steps', type=int, default=30)
    parser.add_argument('--guidance-scale', type=float, default=4.0)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--neg-prompt', default='')
    parser.add_argument('--basesize', type=int, default=1024, help='Resize bucket base size for image editing inputs.')
    parser.add_argument('--rewrite-prompt', action='store_true')
    parser.add_argument('--config', help='Optional config path. Defaults to /infer_config.py.')
    parser.add_argument('--rewrite-model', default='gpt-5')
    parser.add_argument('--hsdp-shard-dim', type=int, help='Override config hsdp_shard_dim for multi-GPU FSDP inference.')
    return parser.parse_args(argv)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions