Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions convert_lora_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
return self.transpose(axis0, axis1)

def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]:
shape = self.shape
ndim = len(shape)
if dim < 0:
dim += ndim
if dim == ndim - 1:
A_chunks = self._lora_A.split(split_size, dim=-1)
return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks)
elif dim == ndim - 2:
B_chunks = self._lora_B.split(split_size, dim=-2)
return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
else:
B_chunks = self._lora_B.split(split_size, dim=dim)
if self._lora_A.shape[dim] == 1:
return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
A_chunks = self._lora_A.split(split_size, dim=dim)
return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks))

def to(self, *args, **kwargs):
return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))

Expand Down Expand Up @@ -230,6 +248,11 @@ def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
)
else:
raise NotImplementedError
elif func is torch.split:
assert len(args) and len(args) >= 2
tensor, split_size = args[0], args[1]
dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
return tensor.split(split_size, dim=dim)
else:
raise NotImplementedError

Expand Down
46 changes: 39 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) {
char base[256];
char name[256];

const ggml_type tsrc0 = op->src[0]->type;
const ggml_type tsrc1 = op->src[1]->type;
const int ne12 = op->src[1]->ne[2];
const int r2 = ne12 / op->src[0]->ne[2];
const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3];

GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX);

snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

Expand Down Expand Up @@ -687,15 +698,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0)
: (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0);

GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX);
const int16_t ne12 = (int16_t) op->src[1]->ne[2];
const int16_t ne13 = (int16_t) op->src[1]->ne[3];
const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]);
const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]);

snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d",
base, bc_inp, bc_out, ne12, ne13, r2, r3);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2);
ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3);
ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4);
ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

Expand Down Expand Up @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
}
};

GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX);
const int16_t r2 = (int16_t) (ne12 / ne02);
const int16_t r3 = (int16_t) (ne13 / ne03);

snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
snprintf(name, 256, "%s_nsg=%d", base, nsg);
snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

Expand Down Expand Up @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3);
ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
GGML_ABORT("unsupported ne11");
};

auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg);

ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,
Expand Down
Loading
Loading