From b34d4489380f65b8ce19ad1b54714513eaa8bf89 Mon Sep 17 00:00:00 2001 From: dharmateja03 <59060125+dharmateja03@users.noreply.github.com> Date: Tue, 12 May 2026 21:22:16 -0400 Subject: [PATCH] Add Python shell command and arithmetic substitution --- bash-py/supermemory_bash/_parse.py | 393 +++++++++++++++++++++++++++-- bash-py/supermemory_bash/_shell.py | 45 ++-- bash-py/tests/test_shell.py | 44 ++++ 3 files changed, 447 insertions(+), 35 deletions(-) diff --git a/bash-py/supermemory_bash/_parse.py b/bash-py/supermemory_bash/_parse.py index a43fc7b..cac1ea1 100644 --- a/bash-py/supermemory_bash/_parse.py +++ b/bash-py/supermemory_bash/_parse.py @@ -1,31 +1,39 @@ """Bridge between just-bash-py's AST and our shell execution layer.""" from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Awaitable, Callable +from dataclasses import dataclass -from ._vendor.just_bash import parse as _jb_parse, ParseException +from ._vendor.just_bash import ParseException +from ._vendor.just_bash import parse as _jb_parse from ._vendor.just_bash.ast.types import ( - ScriptNode, - StatementNode, - PipelineNode, - SimpleCommandNode, - WordNode, - RedirectionNode, + ArithAssignmentNode, + ArithBinaryNode, + ArithExpr, + ArithGroupNode, + ArithmeticExpansionPart, + ArithNestedNode, + ArithNumberNode, + ArithTernaryNode, + ArithUnaryNode, + ArithVariableNode, + AssignDefaultOp, AssignmentNode, - HereDocNode, - LiteralPart, - SingleQuotedPart, + CommandSubstitutionPart, + DefaultValueOp, DoubleQuotedPart, + ErrorIfUnsetOp, EscapedPart, + GlobPart, + HereDocNode, + LiteralPart, ParameterExpansionPart, + RedirectionNode, + ScriptNode, + SingleQuotedPart, TildeExpansionPart, - DefaultValueOp, UseAlternativeOp, - AssignDefaultOp, - ErrorIfUnsetOp, - CommandSubstitutionPart, - ArithmeticExpansionPart, - GlobPart, + WordNode, ) @@ -41,6 +49,9 @@ class Redirect: content: str = "" +CommandSubstitutionRunner = Callable[[ScriptNode], Awaitable[str]] + + def parse_command(cmd: str) -> ScriptNode: try: return _jb_parse(cmd) @@ -119,6 +130,175 @@ def expand_words(words: tuple[WordNode, ...], env: dict[str, str]) -> list[str]: return [expand_word(w, env) for w in words] +async def expand_word_async( + word: WordNode, + env: dict[str, str], + command_runner: CommandSubstitutionRunner, +) -> str: + parts: list[str] = [] + for p in word.parts: + parts.append(await _expand_part_async(p, env, command_runner)) + return "".join(parts) + + +async def _expand_part_async( + part: object, + env: dict[str, str], + command_runner: CommandSubstitutionRunner, +) -> str: + if isinstance(part, LiteralPart): + return part.value + + if isinstance(part, SingleQuotedPart): + return part.value + + if isinstance(part, DoubleQuotedPart): + expanded = [await _expand_part_async(p, env, command_runner) for p in part.parts] + return "".join(expanded) + + if isinstance(part, EscapedPart): + return part.value + + if isinstance(part, ParameterExpansionPart): + return await _expand_parameter_async(part, env, command_runner) + + if isinstance(part, TildeExpansionPart): + return env.get("HOME", "/home/user") + + if isinstance(part, GlobPart): + return part.pattern + + if isinstance(part, CommandSubstitutionPart): + if part.body is None: + return "" + # Bash removes trailing newlines from command substitution output. + return (await command_runner(part.body)).rstrip("\n") + + if isinstance(part, ArithmeticExpansionPart): + return str(eval_arithmetic_expansion(part, env)) + + return "" + + +async def _expand_parameter_async( + part: ParameterExpansionPart, + env: dict[str, str], + command_runner: CommandSubstitutionRunner, +) -> str: + name = part.parameter + if name == "?": + return env.get("?", "0") + val = env.get(name) + op = part.operation + if op is None: + return val if val is not None else "" + if isinstance(op, DefaultValueOp): + default = await expand_word_async(op.word, env, command_runner) if op.word else "" + if op.check_empty: + return val if val else default + return val if val is not None else default + if isinstance(op, UseAlternativeOp): + alt = await expand_word_async(op.word, env, command_runner) if op.word else "" + if op.check_empty: + return alt if val else "" + return alt if val is not None else "" + if isinstance(op, AssignDefaultOp): + default = await expand_word_async(op.word, env, command_runner) if op.word else "" + if (op.check_empty and not val) or (not op.check_empty and val is None): + env[name] = default + return default + return val or "" + if isinstance(op, ErrorIfUnsetOp): + if (op.check_empty and not val) or (not op.check_empty and val is None): + msg = ( + await expand_word_async(op.word, env, command_runner) + if op.word + else f"{name}: parameter null or not set" + ) + raise UnsupportedSyntaxError(msg) + return val or "" + return val if val is not None else "" + + +def _word_has_unquoted_command_substitution(word: WordNode) -> bool: + return any(isinstance(part, CommandSubstitutionPart) for part in word.parts) + + +async def expand_words_async( + words: tuple[WordNode, ...], + env: dict[str, str], + command_runner: CommandSubstitutionRunner, +) -> list[str]: + out: list[str] = [] + for word in words: + expanded = await expand_word_async(word, env, command_runner) + if _word_has_unquoted_command_substitution(word): + out.extend(expanded.split()) + else: + out.append(expanded) + return out + + +async def extract_redirects_async( + redirections: tuple[RedirectionNode, ...], + env: dict[str, str], + command_runner: CommandSubstitutionRunner, +) -> list[Redirect]: + out: list[Redirect] = [] + for r in redirections: + op = r.operator + + if isinstance(r.target, HereDocNode): + content = ( + await expand_word_async(r.target.content, env, command_runner) + if r.target.content + else "" + ) + out.append(Redirect(op="<", fd=0, content=content)) + continue + + target_path = ( + await expand_word_async(r.target, env, command_runner) + if isinstance(r.target, WordNode) + else str(r.target) + ) + + if op in (">", ">|"): + fd = r.fd if r.fd is not None else 1 + out.append(Redirect(op=">", path=target_path, fd=fd)) + elif op == ">>": + fd = r.fd if r.fd is not None else 1 + out.append(Redirect(op=">>", path=target_path, fd=fd)) + elif op == "<": + out.append(Redirect(op="<", path=target_path, fd=0)) + elif op == "<<" or op == "<<-": + out.append(Redirect(op="<", fd=0, content="")) + elif op == "&>" or op == "&>>": + redir_op = ">>" if op == "&>>" else ">" + out.append(Redirect(op=redir_op, path=target_path, fd=1)) + out.append(Redirect(op=redir_op, path=target_path, fd=2)) + elif op == ">&" or op == "<&": + pass + else: + out.append(Redirect(op=">", path=target_path, fd=r.fd or 1)) + + return out + + +async def extract_assignments_async( + assignments: tuple[AssignmentNode, ...], + env: dict[str, str], + command_runner: CommandSubstitutionRunner, +) -> dict[str, str]: + out: dict[str, str] = {} + for a in assignments: + if a.value: + out[a.name] = await expand_word_async(a.value, env, command_runner) + else: + out[a.name] = "" + return out + + def extract_redirects( redirections: tuple[RedirectionNode, ...], env: dict[str, str], @@ -132,7 +312,11 @@ def extract_redirects( out.append(Redirect(op="<", fd=0, content=content)) continue - target_path = expand_word(r.target, env) if isinstance(r.target, WordNode) else str(r.target) + target_path = ( + expand_word(r.target, env) + if isinstance(r.target, WordNode) + else str(r.target) + ) if op in (">", ">|"): fd = r.fd if r.fd is not None else 1 @@ -143,8 +327,7 @@ def extract_redirects( elif op == "<": out.append(Redirect(op="<", path=target_path, fd=0)) elif op == "<<" or op == "<<-": - content = expand_word(r.target.content, env) if hasattr(r.target, "content") and r.target.content else "" - out.append(Redirect(op="<", fd=0, content=content)) + out.append(Redirect(op="<", fd=0, content="")) elif op == "&>" or op == "&>>": # Redirect both stdout and stderr append = op == "&>>" @@ -171,3 +354,173 @@ def extract_assignments( else: out[a.name] = "" return out + + +def eval_arithmetic_expansion(part: ArithmeticExpansionPart, env: dict[str, str]) -> int: + expr = part.expression.expression if part.expression else None + return _eval_arithmetic(expr, env) + + +def _env_int(env: dict[str, str], name: str) -> int: + try: + return int(env.get(name, "0") or "0", 0) + except ValueError: + return 0 + + +def _trunc_div(lhs: int, rhs: int) -> int: + if rhs == 0: + raise UnsupportedSyntaxError("division by zero") + return int(lhs / rhs) + + +def _eval_arithmetic(expr: ArithExpr | None, env: dict[str, str]) -> int: + if expr is None: + return 0 + + if isinstance(expr, ArithNumberNode): + return expr.value + + if isinstance(expr, ArithVariableNode): + return _env_int(env, expr.name) + + if isinstance(expr, ArithGroupNode): + return _eval_arithmetic(expr.expression, env) + + if isinstance(expr, ArithNestedNode): + return _eval_arithmetic(expr.expression, env) + + if isinstance(expr, ArithTernaryNode): + branch = expr.consequent if _eval_arithmetic(expr.condition, env) != 0 else expr.alternate + return _eval_arithmetic(branch, env) + + if isinstance(expr, ArithUnaryNode): + return _eval_arithmetic_unary(expr, env) + + if isinstance(expr, ArithBinaryNode): + return _eval_arithmetic_binary(expr, env) + + if isinstance(expr, ArithAssignmentNode): + return _eval_arithmetic_assignment(expr, env) + + raise UnsupportedSyntaxError(f"unsupported arithmetic expression: {type(expr).__name__}") + + +def _eval_arithmetic_unary(expr: ArithUnaryNode, env: dict[str, str]) -> int: + if isinstance(expr.operand, ArithVariableNode) and expr.operator in ("++", "--"): + old = _env_int(env, expr.operand.name) + new = old + 1 if expr.operator == "++" else old - 1 + env[expr.operand.name] = str(new) + return new if expr.prefix else old + + value = _eval_arithmetic(expr.operand, env) + if expr.operator == "-": + return -value + if expr.operator == "+": + return value + if expr.operator == "!": + return 0 if value else 1 + if expr.operator == "~": + return ~value + raise UnsupportedSyntaxError(f"unsupported arithmetic operator: {expr.operator}") + + +def _eval_arithmetic_binary(expr: ArithBinaryNode, env: dict[str, str]) -> int: + op = expr.operator + + if op == "&&": + return ( + 1 + if _eval_arithmetic(expr.left, env) != 0 + and _eval_arithmetic(expr.right, env) != 0 + else 0 + ) + if op == "||": + return ( + 1 + if _eval_arithmetic(expr.left, env) != 0 + or _eval_arithmetic(expr.right, env) != 0 + else 0 + ) + + lhs = _eval_arithmetic(expr.left, env) + rhs = _eval_arithmetic(expr.right, env) + + if op == "+": + return lhs + rhs + if op == "-": + return lhs - rhs + if op == "*": + return lhs * rhs + if op == "/": + return _trunc_div(lhs, rhs) + if op == "%": + if rhs == 0: + raise UnsupportedSyntaxError("division by zero") + return lhs % rhs + if op == "**": + return int(pow(lhs, rhs)) + if op == "<<": + return lhs << rhs + if op == ">>": + return lhs >> rhs + if op == "<": + return 1 if lhs < rhs else 0 + if op == "<=": + return 1 if lhs <= rhs else 0 + if op == ">": + return 1 if lhs > rhs else 0 + if op == ">=": + return 1 if lhs >= rhs else 0 + if op == "==": + return 1 if lhs == rhs else 0 + if op == "!=": + return 1 if lhs != rhs else 0 + if op == "&": + return lhs & rhs + if op == "|": + return lhs | rhs + if op == "^": + return lhs ^ rhs + if op == ",": + return rhs + + raise UnsupportedSyntaxError(f"unsupported arithmetic operator: {op}") + + +def _eval_arithmetic_assignment(expr: ArithAssignmentNode, env: dict[str, str]) -> int: + if expr.subscript is not None: + raise UnsupportedSyntaxError("arithmetic array assignment is not supported") + + current = _env_int(env, expr.variable) + value = _eval_arithmetic(expr.value, env) + + if expr.operator == "=": + result = value + elif expr.operator == "+=": + result = current + value + elif expr.operator == "-=": + result = current - value + elif expr.operator == "*=": + result = current * value + elif expr.operator == "/=": + result = _trunc_div(current, value) + elif expr.operator == "%=": + if value == 0: + raise UnsupportedSyntaxError("division by zero") + result = current % value + elif expr.operator == "<<=": + result = current << value + elif expr.operator == ">>=": + result = current >> value + elif expr.operator == "&=": + result = current & value + elif expr.operator == "|=": + result = current | value + elif expr.operator == "^=": + result = current ^ value + else: + raise UnsupportedSyntaxError(f"unsupported arithmetic operator: {expr.operator}") + + env[expr.variable] = str(result) + return result diff --git a/bash-py/supermemory_bash/_shell.py b/bash-py/supermemory_bash/_shell.py index 8b32681..12f0009 100644 --- a/bash-py/supermemory_bash/_shell.py +++ b/bash-py/supermemory_bash/_shell.py @@ -2,24 +2,23 @@ import fnmatch import re -from dataclasses import dataclass, field -from typing import Callable, Awaitable +from collections.abc import Awaitable, Callable +from dataclasses import dataclass from ._errors import FsError from ._parse import ( - Redirect, UnsupportedSyntaxError, + expand_word_async, + expand_words_async, + extract_assignments_async, + extract_redirects_async, parse_command, - expand_word, - expand_words, - extract_redirects, - extract_assignments, ) from ._vendor.just_bash.ast.types import ( - SimpleCommandNode, PipelineNode, - StatementNode, ScriptNode, + SimpleCommandNode, + StatementNode, ) from ._volume import SupermemoryVolume @@ -156,7 +155,11 @@ async def _exec_command(self, node: object, stdin: str = "") -> ExecResult: exit_code=2, ) - assigns = extract_assignments(node.assignments, self.env) + assigns = await extract_assignments_async( + node.assignments, + self.env, + self._run_command_substitution, + ) for k, v in assigns.items(): self.env[k] = v @@ -164,9 +167,13 @@ async def _exec_command(self, node: object, stdin: str = "") -> ExecResult: return ExecResult() try: - name = expand_word(node.name, self.env) - args = expand_words(node.args, self.env) - redirects = extract_redirects(node.redirections, self.env) + name = await expand_word_async(node.name, self.env, self._run_command_substitution) + args = await expand_words_async(node.args, self.env, self._run_command_substitution) + redirects = await extract_redirects_async( + node.redirections, + self.env, + self._run_command_substitution, + ) except UnsupportedSyntaxError as e: return ExecResult(stderr=f"{e}\n", exit_code=2) @@ -195,7 +202,7 @@ async def _exec_command(self, node: object, stdin: str = "") -> ExecResult: if redir.content or redir.op == "<": continue - path = self._resolve(redir.path) if redir.path else None + path: str | None = self._resolve(redir.path) if redir.path else None if path == "/dev/null": if redir.fd == 1: @@ -228,6 +235,10 @@ async def _exec_command(self, node: object, stdin: str = "") -> ExecResult: return result + async def _run_command_substitution(self, script: ScriptNode) -> str: + result = await self._exec_script(script) + return result.stdout + # ------------------------------------------------------------------ # Built-in commands # ------------------------------------------------------------------ @@ -835,7 +846,11 @@ async def _cmd_find(self, args: list[str], stdin: str) -> ExecResult: prefix = base_path if base_path.endswith("/") else f"{base_path}/" summaries = await self.volume.list_by_prefix(prefix) - results: list[str] = [base_path] + results: list[str] = [] + if type_filter != "f": + base_name = base_path.rsplit("/", 1)[-1] or "/" + if name_pattern is None or fnmatch.fnmatch(base_name, name_pattern): + results.append(base_path) seen_dirs: set[str] = set() for s in summaries: rest = s.filepath[len(prefix):] diff --git a/bash-py/tests/test_shell.py b/bash-py/tests/test_shell.py index 56db5d4..bde33c9 100644 --- a/bash-py/tests/test_shell.py +++ b/bash-py/tests/test_shell.py @@ -608,6 +608,50 @@ async def test_grep_c_counts_each_match_once(shell_and_vol): assert r.stdout.strip() == "2", r.stdout +@pytest.mark.asyncio +async def test_command_substitution_expands_stdout(shell_and_vol): + shell, vol = shell_and_vol + await vol.add_doc("/notes/a.txt", "alpha") + r = await shell.exec("cat $(find /notes -type f -name '*.txt')") + assert r.exit_code == 0, r + assert r.stdout == "alpha" + + +@pytest.mark.asyncio +async def test_command_substitution_splits_unquoted_lines(shell_and_vol): + shell, vol = shell_and_vol + await vol.add_doc("/notes/a.txt", "alpha\n") + await vol.add_doc("/notes/b.txt", "beta\n") + r = await shell.exec("cat $(find /notes -type f -name '*.txt')") + assert r.exit_code == 0, r + assert r.stdout == "alpha\nbeta\n" + + +@pytest.mark.asyncio +async def test_backtick_command_substitution(shell_and_vol): + shell, vol = shell_and_vol + await vol.add_doc("/notes/a.txt", "alpha") + r = await shell.exec("cat `find /notes -type f -name '*.txt'`") + assert r.exit_code == 0, r + assert r.stdout == "alpha" + + +@pytest.mark.asyncio +async def test_arithmetic_expansion(shell_and_vol): + shell, vol = shell_and_vol + r = await shell.exec("echo $((1 + 2 * 3))") + assert r.exit_code == 0, r + assert r.stdout == "7\n" + + +@pytest.mark.asyncio +async def test_arithmetic_expansion_uses_and_updates_variables(shell_and_vol): + shell, vol = shell_and_vol + r = await shell.exec("n=4 echo $((n += 3)) $((n * 2))") + assert r.exit_code == 0, r + assert r.stdout == "7 14\n" + + # --- Nested-synthetic-dir bugs (from v6 PR review) --- @pytest.mark.asyncio