diff --git a/gemma/peft/_einsum_utils.py b/gemma/peft/_einsum_utils.py index f57dcec3..e825ce53 100644 --- a/gemma/peft/_einsum_utils.py +++ b/gemma/peft/_einsum_utils.py @@ -28,7 +28,7 @@ def get_lora_einsum_str_and_shapes( ) -> tuple[str, _Shape, _Shape]: """Extract the LoRA decomposition from the original einsum parameters. - This function reqrites a einsum string `inputs,weights->outputs` into + This function rewrites an einsum string `inputs,weights->outputs` into `inputs,a,b->outputs`. Args: @@ -66,7 +66,7 @@ def get_lora_einsum_str_and_shapes( lora_einsum_str = f'{inputs},{a_str},{b_str}->{outputs}' - # This assume there's no elipsis in the weights. + # This assumes there's no ellipsis in the weights. weights_str_to_dim = dict(zip(weights, weights_shape)) weights_str_to_dim[rank_dim] = rank a_shape = tuple(weights_str_to_dim[c] for c in a_str) @@ -78,7 +78,6 @@ def get_lora_einsum_str_and_shapes( def _split_einsum_str(einsum_str: str) -> tuple[str, str, str]: """Splits an einsum string into its components.""" - # TODO(epot): Check length def _check_len2(x): if len(x) != 2: raise ValueError(