diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 144d0883..00000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,189 +0,0 @@ -params: ¶ms - parameters: - version: - description: Python docker image version - default: 3.9.16 - type: string - -job_defaults: &job_defaults - <<: *params - executor: - name: python - version: << parameters.version >> - -version: 2.1 - -executors: - python: - <<: *params - docker: - - image: cimg/python:<< parameters.version >> - -jobs: - tests: - description: Run test suite for a specific python version - <<: *job_defaults - steps: - - checkout - - restore_cache: &restore_cache - keys: - - sshtunnel-py<< parameters.version >>-{{ checksum "sshtunnel.py" }}-{{ checksum "tests/requirements.txt" }}-0 - - run: &install - name: Install sshtunnel and build&test dependencies - command: | - python --version - pipenv --version - pip --version - pipenv install -e . - pipenv install --dev -r tests/requirements.txt - cat Pipfile.lock - environment: - - PIPENV_VENV_IN_PROJECT: 1 - - save_cache: &save_cache - key: sshtunnel-py<< parameters.version >>-{{ checksum "sshtunnel.py" }}-{{ checksum "tests/requirements.txt" }}-0 - paths: - - .venv/ - - run: - name: Run test suite - command: >- - pipenv run py.test tests - --showlocals - --durations=10 - -n4 - -W ignore::DeprecationWarning - --cov sshtunnel - --cov-report=html:test_results/coverage.html - --cov-report=term - --junit-xml=test_results/report.xml - - run: - name: Coveralls - command: pipenv run coveralls - - store_test_results: - path: test_results - - store_artifacts: - path: test_results - - docs: - description: Produce documentation from source - <<: *job_defaults - steps: - - checkout - - restore_cache: *restore_cache - - run: *install - - save_cache: *save_cache - - run: - name: Installing documentation dependencies - command: pipenv install --dev -r docs/requirements.txt - - run: - name: Build documentation - command: pipenv run sphinx-build -WavE -b html docs _build/html - - store_artifacts: - path: _build/html - destination: sshtunnel-docs - - syntax: - description: Run syntax validation tests - <<: *job_defaults - steps: - - checkout - - restore_cache: *restore_cache - - run: *install - - save_cache: *save_cache - - run: - name: Installing syntax checks dependencies - command: pipenv install --dev -r tests/requirements-syntax.txt - - run: - name: checking MANIFEST.in - command: pipenv run check-manifest --ignore tox.ini,tests*,*.yml - - run: - name: checking RST syntax - command: | - pipenv run python setup.py sdist - pipenv run twine check dist/* - - run: - name: checking PEP8 compliancy - command: pipenv run flake8 --exclude .venv,build,docs,e2e_tests --max-complexity 10 --ignore=W504 - - run: - name: checking CLI help - command: pipenv run bashtest README.rst - - testdeploy: - description: Build and upload artifacts to Test PyPI - <<: *job_defaults - steps: - - checkout - - restore_cache: *restore_cache - - run: - name: Build artifact - command: | - pipenv run python setup.py bdist_egg bdist_wheel sdist - - run: - name: Check artifacts - command: pipenv run twine check dist/* - - store_artifacts: - path: dist/ - - run: - name: Upload to TestPyPI - command: >- - pipenv run twine upload - --repository testpypi - --username __token__ - --password $TESTPYPI_TOKEN - --skip-existing - dist/* - - deploy: - description: Build and upload artifacts to PyPI - <<: *job_defaults - steps: - - checkout - - restore_cache: *restore_cache - - run: - name: Build artifact - command: | - pipenv run python setup.py bdist_egg bdist_wheel sdist - - run: - name: Upload to PyPI - command: >- - pipenv run twine upload - --username __token__ - --password $PYPI_TOKEN - --skip-existing - dist/* - -workflows: - syntax_and_docs: - jobs: - - syntax - - docs - - test_and_deploy: - jobs: - - tests: - matrix: - parameters: - version: - - "2.7" - - "3.4" - - "3.5" - - "3.6" - - "3.7" - - "3.8" - - testdeploy: - requires: - - tests - - - hold: - type: approval - requires: - - testdeploy - filters: - branches: - only: master - - - deploy: - requires: - - hold - filters: - branches: - only: master diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..69634622 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,106 @@ +name: CI + +on: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}${{ github.ref == 'refs/heads/master' && format('-{0}', github.sha) || '' }} + cancel-in-progress: ${{ github.ref != 'refs/heads/master' && github.ref != 'refs/heads/main' }} + +jobs: + matrix: + runs-on: ubuntu-latest + timeout-minutes: 5 + outputs: + python-versions: ${{ steps.parse.outputs.python-versions }} + paramiko-versions: ${{ steps.parse.outputs.paramiko-versions }} + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + sparse-checkout: noxfile.py + + - name: Install uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + + - id: parse + name: Extract matrix from noxfile.py + run: | + sessions=$(uvx --from 'nox[uv]' nox --list --json 2>/dev/null) + echo "python-versions=$(echo "$sessions" | jq -c '[.[] | select(.name == "tests") | .python] | unique')" >> "$GITHUB_OUTPUT" + echo "paramiko-versions=$(echo "$sessions" | jq -c '[.[] | select(.name == "tests") | .call_spec.paramiko] | unique')" >> "$GITHUB_OUTPUT" + + test: + needs: matrix + runs-on: ubuntu-latest + timeout-minutes: 20 + strategy: + fail-fast: false + matrix: + python-version: ${{ fromJSON(needs.matrix.outputs.python-versions) }} + paramiko-version: ${{ fromJSON(needs.matrix.outputs.paramiko-versions) }} + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Install uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + + - name: Run tests (paramiko ${{ matrix.paramiko-version }}) + run: > + uvx --from 'nox[uv]' nox + -s "tests-${{ matrix.python-version }}(paramiko='${{ matrix.paramiko-version }}')" + + e2e: + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Install uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Set up Python + run: uv python install 3.13 + + - name: Run e2e tests + run: uvx --from 'nox[uv]' nox -s e2e + + lint: + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Install uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Set up Python + run: uv python install 3.12 + + - name: Run ruff check + run: uvx ruff check . + + - name: Run ruff format check + run: uvx ruff format --check . diff --git a/.github/workflows/database.yml b/.github/workflows/database.yml deleted file mode 100644 index 2f0e5d9d..00000000 --- a/.github/workflows/database.yml +++ /dev/null @@ -1,73 +0,0 @@ -name: Test tunnel for databases connection - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -jobs: - build: - - runs-on: ubuntu-20.04 - - strategy: - matrix: - python: [ python3, python2 ] - - steps: - - uses: actions/checkout@v2 - - - name: Docker compose up databases and ssh-server - run: | - # openssh-server trying to change file permissions to 0600 and we want to do it in /tmp directory - cp -r ./e2e_tests/ssh-server-config /tmp/ssh - sed -i "s#./ssh-server-config#/tmp/ssh#g" ./e2e_tests/docker-compose.yaml - chmod 600 ./e2e_tests/ssh-server-config/ssh_host_rsa_key - cd e2e_tests && docker-compose up -d - - - name: Install dependencies - run: | - id - uname -a - lsb_release -a - ${{ matrix.python }} -V - curl https://bootstrap.pypa.io/pip/2.7/get-pip.py -o get-pip.py - ${{ matrix.python }} get-pip.py - ${{ matrix.python }} -m pip install --upgrade pip - ${{ matrix.python }} -m pip install . - ${{ matrix.python }} -m pip install psycopg2-binary>=2.9.6 pymysql>=1.0.3 pymongo>=4.3.3 - ${{ matrix.python }} -m pip install --upgrade pyopenssl - - ssh -o "StrictHostKeyChecking=no" linuxserver@127.0.0.1 -p 2223 -i ./e2e_tests/ssh-server-config/ssh_host_rsa_key -vvvvv "uname -a" - - # cd e2e_tests && docker-compose logs ssh; cd .. - # cd e2e_tests && docker-compose exec ssh cat /config/logs/openssh/current; cd .. - # docker exec openssh-server tail -f /config/logs/openssh/current - - - name: Run db tests ${{ matrix.python }} - run: ${{ matrix.python }} e2e_tests/run_docker_e2e_db_tests.py - - - name: Run hungs tests ${{ matrix.python }} - run: timeout 10s ${{ matrix.python }} e2e_tests/run_docker_e2e_hangs_tests.py - - - name: Collect openssh-server logs from docker container - if: failure() - run: docker exec openssh-server cat /config/logs/openssh/current > openssh-server.log - - - name: Collect docker stdout logs - if: failure() - uses: jwalton/gh-docker-logs@v1 - with: - dest: './docker-logs' - - - name: Upload log artifact on failure - uses: actions/upload-artifact@v3 - if: failure() - with: - name: logs - path: | - e2e_tests/*.log - ./docker-logs - *.log - retention-days: 30 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..3d1f1f0b --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,75 @@ +name: Publish to PyPI + +on: + push: + tags: + - '[0-9]*.[0-9]*.[0-9]*' + +permissions: {} + +jobs: + build: + name: Build package + runs-on: ubuntu-latest + timeout-minutes: 20 + permissions: + contents: read + steps: + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Install uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Set up Python + run: uv python install 3.12 + + - name: Build package + run: uv build + + - name: Smoke test wheel + run: | + uv run --isolated --no-project --with dist/*.whl python -c " + import sshtunnel + version = sshtunnel.__version__ + expected = '${{ github.ref_name }}' + print(f'Built version: {version}') + assert version == expected, f'Version mismatch: built {version!r}, expected {expected!r} from tag' + " + + - name: Upload distributions + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: dist + path: dist/* + if-no-files-found: error + retention-days: 7 + + publish: + name: Publish to PyPI + runs-on: ubuntu-latest + timeout-minutes: 10 + needs: build + environment: pypi + permissions: + id-token: write + contents: read + steps: + - name: Install uv + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Download distributions + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: dist + path: dist + + - name: Publish to PyPI + run: uv publish --trusted-publishing always dist/* diff --git a/.gitignore b/.gitignore index 03151b98..dc4d8cd5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,29 +1,32 @@ -# sshtunnel -.python-version -.venv -.idea -.env - -# general things to ignore -build/ -dist/ -*.egg-info/ -*.egg*/ +# Python *.py[cod] __pycache__/ *.so -.*cache*/ -#*~ +*.egg-info/ +*.egg*/ +build/ +dist/ + +# Virtual environments +.venv/ +.python-version + +# IDE +.idea/ +.vscode/ -# due to using tox and pytest -.tox -.cache +# Environment +.env + +# Testing .coverage* -*cov* +htmlcov/ pytestdebug.log +pytest.xml +.pytest_cache/ -# due to sphinx -docs/_build/ +# uv +uv.lock -# Pipfile -Pipfile* +# Docs +docs/_build/ diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 6b7ecac9..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,16 +0,0 @@ -# Include the data files recursive-include data * - -include LICENSE -include *.rst -include docs/conf.py -include docs/Makefile -include docs/*.rst -include docs/*.txt -include tests/* -include e2e_tests/* -include e2e_tests/ssh-server-config/* -exclude .github/* -exclude .circleci/* -exclude *.pyc -exclude __pycache__ -exclude Pipfile* diff --git a/README.rst b/README.rst index 7400816e..86d41eb6 100644 --- a/README.rst +++ b/README.rst @@ -1,52 +1,40 @@ -|CircleCI| |AppVeyor| |readthedocs| |coveralls| |version| +|CI| |pyversions| |version| |license| -|pyversions| |license| +``deepnote-sshtunnel`` -- Pure python SSH tunnels +================================================== -**Author**: `Pahaz`_ +This is a `Deepnote `_ fork of `pahaz/sshtunnel`_ with +the following changes: -**Repo**: https://github.com/pahaz/sshtunnel/ +- **Python 3.10+** only (dropped Python 2 and older 3.x support) +- **paramiko 3, 4, and 5** compatibility (removed deprecated DSA/DSSKey support) +- Modern packaging with ``pyproject.toml``, ``hatchling``, and ``uv`` +- GitHub Actions CI and trusted PyPI publishing -Inspired by https://github.com/jmagnusson/bgtunnel, which doesn't work on -Windows. +**Original author**: `Pahaz`_ -See also: https://github.com/paramiko/paramiko/blob/master/demos/forward.py +**Upstream repo**: https://github.com/pahaz/sshtunnel/ Requirements -------------- +------------ -* `paramiko`_ +* `paramiko`_ >= 3.4 +* Python >= 3.10 Installation ============ -`sshtunnel`_ is on PyPI, so simply run: +``deepnote-sshtunnel`` is on PyPI, so simply run:: -:: + pip install deepnote-sshtunnel - pip install sshtunnel +or:: -or :: + uv add deepnote-sshtunnel - easy_install sshtunnel +The import name remains ``sshtunnel`` for drop-in compatibility:: -or :: - - conda install -c conda-forge sshtunnel - -to have it installed in your environment. - -For installing from source, clone the -`repo `_ and run:: - - python setup.py install - -Testing the package -------------------- - -In order to run the tests you first need -`tox `_ and run:: - - python setup.py test + from sshtunnel import SSHTunnelForwarder Usage scenarios =============== @@ -213,77 +201,13 @@ time. ssh.exec_command(...) -CLI usage -========= - -:: - - $ sshtunnel --help - usage: sshtunnel [-h] [-U SSH_USERNAME] [-p SSH_PORT] [-P SSH_PASSWORD] -R - IP:PORT [IP:PORT ...] [-L [IP:PORT ...]] [-k SSH_HOST_KEY] - [-K KEY_FILE] [-S KEY_PASSWORD] [-t] [-v] [-V] [-x IP:PORT] - [-c SSH_CONFIG_FILE] [-z] [-n] [-d [FOLDER ...]] - ssh_address - - Pure python ssh tunnel utils - Version 0.4.0 - - positional arguments: - ssh_address SSH server IP address (GW for SSH tunnels) - set with "-- ssh_address" if immediately after -R or -L - - options: - -h, --help show this help message and exit - -U SSH_USERNAME, --username SSH_USERNAME - SSH server account username - -p SSH_PORT, --server_port SSH_PORT - SSH server TCP port (default: 22) - -P SSH_PASSWORD, --password SSH_PASSWORD - SSH server account password - -R IP:PORT [IP:PORT ...], --remote_bind_address IP:PORT [IP:PORT ...] - Remote bind address sequence: ip_1:port_1 ip_2:port_2 ... ip_n:port_n - Equivalent to ssh -Lxxxx:IP_ADDRESS:PORT - If port is omitted, defaults to 22. - Example: -R 10.10.10.10: 10.10.10.10:5900 - -L [IP:PORT ...], --local_bind_address [IP:PORT ...] - Local bind address sequence: ip_1:port_1 ip_2:port_2 ... ip_n:port_n - Elements may also be valid UNIX socket domains: - /tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock - Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, being the local IP address optional. - By default it will listen in all interfaces (0.0.0.0) and choose a random port. - Example: -L :40000 - -k SSH_HOST_KEY, --ssh_host_key SSH_HOST_KEY - Gateway's host key - -K KEY_FILE, --private_key_file KEY_FILE - RSA/DSS/ECDSA private key file - -S KEY_PASSWORD, --private_key_password KEY_PASSWORD - RSA/DSS/ECDSA private key password - -t, --threaded Allow concurrent connections to each tunnel - -v, --verbose Increase output verbosity (default: ERROR) - -V, --version Show version number and quit - -x IP:PORT, --proxy IP:PORT - IP and port of SSH proxy to destination - -c SSH_CONFIG_FILE, --config SSH_CONFIG_FILE - SSH configuration file, defaults to ~/.ssh/config - -z, --compress Request server for compression over SSH transport - -n, --noagent Disable looking for keys from an SSH agent - -d [FOLDER ...], --host_pkey_directories [FOLDER ...] - List of directories where SSH pkeys (in the format `id_*`) may be found - .. _Pahaz: https://github.com/pahaz -.. _sshtunnel: https://pypi.python.org/pypi/sshtunnel +.. _pahaz/sshtunnel: https://github.com/pahaz/sshtunnel .. _paramiko: http://www.paramiko.org/ -.. |CircleCI| image:: https://circleci.com/gh/pahaz/sshtunnel.svg?style=svg - :target: https://circleci.com/gh/pahaz/sshtunnel -.. |AppVeyor| image:: https://ci.appveyor.com/api/projects/status/oxg1vx2ycmnw3xr9?svg=true&passingText=Windows%20-%20OK&failingText=Windows%20-%20Fail - :target: https://ci.appveyor.com/project/pahaz/sshtunnel -.. |readthedocs| image:: https://readthedocs.org/projects/sshtunnel/badge/?version=latest - :target: http://sshtunnel.readthedocs.io/en/latest/?badge=latest - :alt: Documentation Status -.. |coveralls| image:: https://coveralls.io/repos/github/pahaz/sshtunnel/badge.svg?branch=master - :target: https://coveralls.io/github/pahaz/sshtunnel?branch=master -.. |pyversions| image:: https://img.shields.io/pypi/pyversions/sshtunnel.svg -.. |version| image:: https://img.shields.io/pypi/v/sshtunnel.svg - :target: `sshtunnel`_ -.. |license| image:: https://img.shields.io/pypi/l/sshtunnel.svg - :target: https://github.com/pahaz/sshtunnel/blob/master/LICENSE +.. |CI| image:: https://github.com/deepnote/sshtunnel/actions/workflows/ci.yml/badge.svg + :target: https://github.com/deepnote/sshtunnel/actions/workflows/ci.yml +.. |pyversions| image:: https://img.shields.io/pypi/pyversions/deepnote-sshtunnel.svg +.. |version| image:: https://img.shields.io/pypi/v/deepnote-sshtunnel.svg + :target: https://pypi.org/project/deepnote-sshtunnel/ +.. |license| image:: https://img.shields.io/pypi/l/deepnote-sshtunnel.svg + :target: https://github.com/deepnote/sshtunnel/blob/main/LICENSE diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index ad754699..00000000 --- a/appveyor.yml +++ /dev/null @@ -1,59 +0,0 @@ -platform: x64 - -environment: - - matrix: - - PYTHON: "C:\\Python27" - PYTHON_VERSION: "2.7.x" - PYTHON_ARCH: "32" - - - PYTHON: "C:\\Python27-x64" - PYTHON_VERSION: "2.7.x" - PYTHON_ARCH: "64" - - - PYTHON: "C:\\Python35" - PYTHON_VERSION: "3.5.x" - PYTHON_ARCH: "32" - - - PYTHON: "C:\\Python35-x64" - PYTHON_VERSION: "3.5.x" - PYTHON_ARCH: "64" - - - PYTHON: "C:\\Python36" - PYTHON_VERSION: "3.6.x" - PYTHON_ARCH: "32" - - - PYTHON: "C:\\Python36-x64" - PYTHON_VERSION: "3.6.x" - PYTHON_ARCH: "64" - - - PYTHON: "C:\\Python37" - PYTHON_VERSION: "3.7.x" - PYTHON_ARCH: "32" - - - PYTHON: "C:\\Python37-x64" - PYTHON_VERSION: "3.7.x" - PYTHON_ARCH: "64" - - - PYTHON: "C:\\Python38" - PYTHON_VERSION: "3.8.x" - PYTHON_ARCH: "32" - - - PYTHON: "C:\\Python38-x64" - PYTHON_VERSION: "3.8.x" - PYTHON_ARCH: "64" -init: - - "ECHO %PYTHON% %PYTHON_VERSION% %PYTHON_ARCH%" - -install: - - set "PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\Tools\\Scripts;%PATH%" - - python -m pip install --upgrade pip - - pip install paramiko - - pip install mock pytest pytest-cov pytest-xdist - -build: off - -test_script: - - python setup.py install - - py.test --showlocals --durations=10 -n4 tests - diff --git a/changelog.rst b/changelog.rst index daf3808a..d4c1c914 100644 --- a/changelog.rst +++ b/changelog.rst @@ -22,6 +22,17 @@ CONTRIBUTORS CHANGELOG ========= +- v.1.0.0 — Deepnote fork (`Deepnote`_) + + Fork as ``deepnote-sshtunnel`` for Paramiko v3/v4/v5 compatibility + + Drop Python 2 and Python < 3.10 support + + Remove DSA key support (``paramiko.DSSKey`` removed in Paramiko 4) + + Add Ed25519 host key support + + Modernize packaging: ``hatchling`` + ``uv-dynamic-versioning`` + + Replace tox/CircleCI/AppVeyor with ``nox`` + GitHub Actions + + Add trusted PyPI publishing via OIDC + + Rewrite e2e tests to use ``testcontainers`` with ephemeral SSH keys + + Format codebase with ``ruff`` + - v.0.X.Y (`V0idk`_, `Bruno Inec`_, `alex3d`_) + Remove the potential deadlock that is associated with threading.Lock (`#231`_) + Remove the hidden modification of the logger in cases where a custom logger is used. (`#250`_) @@ -150,6 +161,7 @@ CHANGELOG + ``open`` function (`Pahaz`_) +.. _Deepnote: https://github.com/deepnote .. _Pahaz: https://github.com/pahaz .. _Cameron Maske: https://github.com/cameronmaske .. _Gustavo Machado: https://github.com/gdmachado diff --git a/e2e_tests/__init__.py b/e2e_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/e2e_tests/conftest.py b/e2e_tests/conftest.py new file mode 100644 index 00000000..fdac0186 --- /dev/null +++ b/e2e_tests/conftest.py @@ -0,0 +1,162 @@ +import shutil +import tempfile +import time +from pathlib import Path +from textwrap import dedent + +import paramiko +import pytest +from testcontainers.core.container import DockerContainer +from testcontainers.core.network import Network +from testcontainers.core.wait_strategies import LogMessageWaitStrategy +from testcontainers.core.waiting_utils import wait_for_logs + +PG_USER = "postgres" +PG_PASSWORD = "postgres" +PG_DB = "main" + +MYSQL_USER = "mysql" +MYSQL_PASSWORD = "mysql" +MYSQL_DB = "main" + +MONGO_USER = "mongo" +MONGO_PASSWORD = "mongo" +MONGO_DB = "main" + +OPENSSH_SERVER_IMAGE = ( + "linuxserver/openssh-server:10.2_p1-r0-ls225" + "@sha256:29d4e3f887596c4c2fc609f4e07040b08890a238178da400ffa2a602b55245bc" +) + + +def _generate_ssh_keypair(directory: Path) -> Path: + """Generate an ephemeral RSA keypair for the test run. + + Files are created with 0o666 permissions so the container's init process + (which runs chown to PUID/PGID) can take ownership regardless of host UID. + """ + key = paramiko.RSAKey.generate(bits=2048) + + private_key_path = directory / "ssh_host_rsa_key" + key.write_private_key_file(str(private_key_path)) + private_key_path.chmod(0o666) + + public_key_path = directory / "ssh_host_rsa_key.pub" + public_key_path.write_text(f"{key.get_name()} {key.get_base64()}") + public_key_path.chmod(0o666) + + return private_key_path + + +def _generate_sshd_config(directory: Path) -> Path: + """Generate an sshd_config that permits TCP forwarding. + + File is world-writable so the container's init can move/chown it. + """ + config_path = directory / "sshd_config" + config_path.write_text( + dedent("""\ + Port 2222 + PermitRootLogin no + PasswordAuthentication no + PubkeyAuthentication yes + AllowTcpForwarding yes + GatewayPorts no + X11Forwarding no + PrintMotd no + AcceptEnv LANG LC_* + Subsystem sftp /usr/lib/ssh/sftp-server + AuthorizedKeysFile .ssh/authorized_keys + """) + ) + config_path.chmod(0o666) + return config_path + + +@pytest.fixture(scope="session") +def e2e_infrastructure(): + """Spin up SSH + database containers on a shared Docker network.""" + tmp_key_dir = Path(tempfile.mkdtemp(prefix="sshtunnel-e2e-keys-")) + tmp_key_dir.chmod(0o777) + private_key_path = _generate_ssh_keypair(tmp_key_dir) + sshd_config_path = _generate_sshd_config(tmp_key_dir) + + # Keep a host-side copy of the private key that the container can't touch. + # The container's init chowns bind-mounted files to PUID/PGID, which makes + # the originals unreadable by the CI runner (different UID). + host_key_dir = Path(tempfile.mkdtemp(prefix="sshtunnel-e2e-hostkey-")) + host_private_key = host_key_dir / "ssh_host_rsa_key" + shutil.copy2(private_key_path, host_private_key) + host_private_key.chmod(0o600) + + try: + with Network() as network: + postgres = ( + DockerContainer("postgres:17") + .with_env("POSTGRES_USER", PG_USER) + .with_env("POSTGRES_PASSWORD", PG_PASSWORD) + .with_env("POSTGRES_DB", PG_DB) + .with_network(network) + .with_network_aliases("postgres-db") + ) + + mysql = ( + DockerContainer("mysql:8.4") + .with_env("MYSQL_DATABASE", MYSQL_DB) + .with_env("MYSQL_USER", MYSQL_USER) + .with_env("MYSQL_PASSWORD", MYSQL_PASSWORD) + .with_env("MYSQL_ROOT_PASSWORD", "rootpw") + .with_network(network) + .with_network_aliases("mysql-db") + ) + + mongo = ( + DockerContainer("mongo:8") + .with_env("MONGO_INITDB_ROOT_USERNAME", MONGO_USER) + .with_env("MONGO_INITDB_ROOT_PASSWORD", MONGO_PASSWORD) + .with_env("MONGO_INITDB_DATABASE", MONGO_DB) + .with_network(network) + .with_network_aliases("mongo-db") + ) + + ssh = ( + DockerContainer(OPENSSH_SERVER_IMAGE) + .with_env("PUID", "1000") + .with_env("PGID", "1000") + .with_env("TZ", "UTC") + .with_env("PUBLIC_KEY_FILE", "/config/ssh_host_keys/ssh_host_rsa_key.pub") + .with_env("SUDO_ACCESS", "false") + .with_env("PASSWORD_ACCESS", "false") + .with_env("USER_NAME", "linuxserver") + .with_env("LISTEN_PORT", "2222") + .with_volume_mapping(str(tmp_key_dir), "/config/ssh_host_keys", "rw") + .with_volume_mapping(str(sshd_config_path), "/config/sshd_config", "rw") + .with_exposed_ports(2222) + .with_network(network) + .with_network_aliases("ssh-server") + ) + + with postgres, mysql, mongo, ssh: + wait_for_logs( + postgres, LogMessageWaitStrategy("database system is ready to accept connections"), timeout=60 + ) + wait_for_logs(mysql, LogMessageWaitStrategy("port: 3306"), timeout=90) + wait_for_logs(mongo, LogMessageWaitStrategy("Waiting for connections"), timeout=60) + wait_for_logs(ssh, LogMessageWaitStrategy("done."), timeout=60) + + # The SSH server logs "done." before sshd is fully accepting connections. + # Without this grace period, early tunnel attempts get "Connection refused". + time.sleep(2) + + yield { + "ssh_host": ssh.get_container_host_ip(), + "ssh_port": int(ssh.get_exposed_port(2222)), + "ssh_username": "linuxserver", + "ssh_pkey": str(host_private_key), + "pg": {"host": "postgres-db", "port": 5432}, + "mysql": {"host": "mysql-db", "port": 3306}, + "mongo": {"host": "mongo-db", "port": 27017}, + } + finally: + shutil.rmtree(tmp_key_dir, ignore_errors=True) + shutil.rmtree(host_key_dir, ignore_errors=True) diff --git a/e2e_tests/docker-compose.yaml b/e2e_tests/docker-compose.yaml deleted file mode 100644 index bb4d998c..00000000 --- a/e2e_tests/docker-compose.yaml +++ /dev/null @@ -1,62 +0,0 @@ ---- -version: "2.1" -services: - ssh: - image: linuxserver/openssh-server:version-9.1_p1-r2 - container_name: openssh-server - hostname: openssh-server - environment: - - PUID=1000 - - PGID=1000 - - TZ=Europe/London - - PUBLIC_KEY_FILE=/config/ssh_host_keys/ssh_host_rsa_key.pub - - SUDO_ACCESS=false - - PASSWORD_ACCESS=false - - USER_NAME=linuxserver - - LISTEN_PORT=2222 - volumes: - - ./ssh-server-config:/config/ssh_host_keys - ports: - - "127.0.0.1:2223:2222" - networks: - - inner - - postgresdb: - image: postgres:13.0 - command: ["postgres", "-c", "log_statement=all"] - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: main - networks: - inner: - ipv4_address: 10.5.0.5 - - mysqldb: - image: mysql:8.0.33 - environment: - MYSQL_DATABASE: main - MYSQL_USER: mysql - MYSQL_PASSWORD: mysql - MYSQL_ROOT_PASSWORD: mysqlroot - networks: - inner: - ipv4_address: 10.5.0.6 - - mongodb: - image: mongo:3.6.23 - environment: - MONGO_INITDB_ROOT_USERNAME: mongo - MONGO_INITDB_ROOT_PASSWORD: mongo - MONGO_INITDB_DATABASE: main - networks: - inner: - ipv4_address: 10.5.0.7 - -networks: - inner: - driver: bridge - ipam: - config: - - subnet: 10.5.0.0/16 - gateway: 10.5.0.1 diff --git a/e2e_tests/run_docker_e2e_db_tests.py b/e2e_tests/run_docker_e2e_db_tests.py deleted file mode 100644 index b9ea4df6..00000000 --- a/e2e_tests/run_docker_e2e_db_tests.py +++ /dev/null @@ -1,246 +0,0 @@ -import select -import traceback -import sys -import os -import time -from sshtunnel import SSHTunnelForwarder -import sshtunnel -import logging -import threading -import paramiko - -sshtunnel.DEFAULT_LOGLEVEL = 1 -logging.basicConfig( - format='%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s', level=1) -logger = logging.root - -SSH_SERVER_ADDRESS = ('127.0.0.1', 2223) -SSH_SERVER_USERNAME = 'linuxserver' -SSH_PKEY = os.path.join(os.path.dirname(__file__), 'ssh-server-config', 'ssh_host_rsa_key') -SSH_SERVER_REMOTE_SIDE_ADDRESS_PG = ('10.5.0.5', 5432) -SSH_SERVER_REMOTE_SIDE_ADDRESS_MYSQL = ('10.5.0.6', 3306) -SSH_SERVER_REMOTE_SIDE_ADDRESS_MONGO = ('10.5.0.7', 27017) - -PG_DATABASE_NAME = 'main' -PG_USERNAME = 'postgres' -PG_PASSWORD = 'postgres' -PG_QUERY = 'select version()' -PG_EXPECT = eval( - """('PostgreSQL 13.0 (Debian 13.0-1.pgdg100+1) on x86_64-pc-linux-gnu, compiled by gcc (Debian 8.3.0-6) 8.3.0, 64-bit',)""") - -MYSQL_DATABASE_NAME = 'main' -MYSQL_USERNAME = 'mysql' -MYSQL_PASSWORD = 'mysql' -MYSQL_QUERY = 'select version()' -MYSQL_EXPECT = (('8.0.33',),) - -MONGO_DATABASE_NAME = 'main' -MONGO_USERNAME = 'mongo' -MONGO_PASSWORD = 'mongo' -MONGO_QUERY = lambda client, db: client.server_info() -MONGO_EXPECT = eval( - """{'version': '3.6.23', 'gitVersion': 'd352e6a4764659e0d0350ce77279de3c1f243e5c', 'modules': [], 'allocator': 'tcmalloc', 'javascriptEngine': 'mozjs', 'sysInfo': 'deprecated', 'versionArray': [3, 6, 23, 0], 'openssl': {'running': 'OpenSSL 1.0.2g 1 Mar 2016', 'compiled': 'OpenSSL 1.0.2g 1 Mar 2016'}, 'buildEnvironment': {'distmod': 'ubuntu1604', 'distarch': 'x86_64', 'cc': '/opt/mongodbtoolchain/v2/bin/gcc: gcc (GCC) 5.4.0', 'ccflags': '-fno-omit-frame-pointer -fno-strict-aliasing -ggdb -pthread -Wall -Wsign-compare -Wno-unknown-pragmas -Winvalid-pch -Werror -O2 -Wno-unused-local-typedefs -Wno-unused-function -Wno-deprecated-declarations -Wno-unused-but-set-variable -Wno-missing-braces -fstack-protector-strong -fno-builtin-memcmp', 'cxx': '/opt/mongodbtoolchain/v2/bin/g++: g++ (GCC) 5.4.0', 'cxxflags': '-Woverloaded-virtual -Wno-maybe-uninitialized -std=c++14', 'linkflags': '-pthread -Wl,-z,now -rdynamic -Wl,--fatal-warnings -fstack-protector-strong -fuse-ld=gold -Wl,--build-id -Wl,--hash-style=gnu -Wl,-z,noexecstack -Wl,--warn-execstack -Wl,-z,relro', 'target_arch': 'x86_64', 'target_os': 'linux'}, 'bits': 64, 'debug': False, 'maxBsonObjectSize': 16777216, 'storageEngines': ['devnull', 'ephemeralForTest', 'mmapv1', 'wiredTiger'], 'ok': 1.0}""") - - -def run_postgres_query(port, query=PG_QUERY): - import psycopg2 - - ASYNC_OK = 1 - ASYNC_READ_TIMEOUT = 2 - ASYNC_WRITE_TIMEOUT = 3 - ASYNC_TIMEOUT = 0.2 - - def wait(conn): - while 1: - state = conn.poll() - if state == psycopg2.extensions.POLL_OK: - break - elif state == psycopg2.extensions.POLL_WRITE: - select.select([], [conn.fileno()], []) - elif state == psycopg2.extensions.POLL_READ: - select.select([conn.fileno()], [], []) - else: - raise psycopg2.OperationalError( - "poll() returned %s from _wait function" % state) - - def wait_timeout(conn): - while 1: - state = conn.poll() - if state == psycopg2.extensions.POLL_OK: - return ASYNC_OK - elif state == psycopg2.extensions.POLL_WRITE: - # Wait for the given time and then check the return status - # If three empty lists are returned then the time-out is - # reached. - timeout_status = select.select( - [], [conn.fileno()], [], ASYNC_TIMEOUT - ) - if timeout_status == ([], [], []): - return ASYNC_WRITE_TIMEOUT - elif state == psycopg2.extensions.POLL_READ: - # Wait for the given time and then check the return status - # If three empty lists are returned then the time-out is - # reached. - timeout_status = select.select( - [conn.fileno()], [], [], ASYNC_TIMEOUT - ) - if timeout_status == ([], [], []): - return ASYNC_READ_TIMEOUT - else: - raise psycopg2.OperationalError( - "poll() returned %s from _wait_timeout function" % state - ) - - pg_conn = psycopg2.connect( - host='127.0.0.1', - hostaddr='127.0.0.1', - port=port, - database=PG_DATABASE_NAME, - user=PG_USERNAME, - password=PG_PASSWORD, - sslmode='disable', - async_=1 - ) - wait(pg_conn) - cur = pg_conn.cursor() - cur.execute(query) - res = wait_timeout(cur.connection) - while res != ASYNC_OK: - res = wait_timeout(cur.connection) - return cur.fetchone() - - -def run_mysql_query(port, query=MYSQL_QUERY): - import pymysql - conn = pymysql.connect( - host='127.0.0.1', - port=port, - user=MYSQL_USERNAME, - password=MYSQL_PASSWORD, - database=MYSQL_DATABASE_NAME, - connect_timeout=5, - read_timeout=5) - cursor = conn.cursor() - cursor.execute(query) - return cursor.fetchall() - - -def run_mongo_query(port, query=MONGO_QUERY): - import pymongo - client = pymongo.MongoClient('127.0.0.1', port) - db = client[MONGO_DATABASE_NAME] - return query(client, db) - - -def create_tunnel(): - logging.info('Creating SSHTunnelForwarder... (sshtunnel v%s, paramiko v%s)', - sshtunnel.__version__, paramiko.__version__) - tunnel = SSHTunnelForwarder( - SSH_SERVER_ADDRESS, - ssh_username=SSH_SERVER_USERNAME, - ssh_pkey=SSH_PKEY, - remote_bind_addresses=[ - SSH_SERVER_REMOTE_SIDE_ADDRESS_PG, SSH_SERVER_REMOTE_SIDE_ADDRESS_MYSQL, - SSH_SERVER_REMOTE_SIDE_ADDRESS_MONGO, - ], - logger=logger, - ) - return tunnel - - -def start(tunnel): - try: - logging.info('Trying to start ssh tunnel...') - tunnel.start() - except Exception as e: - logging.exception('Tunnel start exception: %r', e) - raise - - -def run_db_queries(tunnel): - result1, result2, result3 = None, None, None - - try: - logging.info('Trying to run PG query...') - result1 = run_postgres_query(tunnel.local_bind_ports[0]) - logging.info('PG query: %r', result1) - except Exception as e: - logging.exception('PG query exception: %r', e) - raise - - try: - logging.info('Trying to run MYSQL query...') - result2 = run_mysql_query(tunnel.local_bind_ports[1]) - logging.info('MYSQL query: %r', result2) - except Exception as e: - logging.exception('MYSQL query exception: %r', e) - raise - - try: - logging.info('Trying to run MONGO query...') - result3 = run_mongo_query(tunnel.local_bind_ports[2]) - logging.info('MONGO query: %r', result3) - except Exception as e: - logging.exception('MONGO query exception: %r', e) - raise - - return result1, result2, result3 - - -def wait_and_check_or_restart_if_required(tunnel, i=1): - logging.warning('Sleeping for %s second...', i) - while i: - time.sleep(1) - if i % 10 == 0: - logging.info('Running tunnel.check_tunnels... (i=%s)', i) - tunnel.check_tunnels() - logging.info('Check result: %r (i=%s)', tunnel.tunnel_is_up, i) - if not tunnel.is_active: - logging.warning('Tunnel is DOWN! restarting ...') - tunnel.restart() - i -= 1 - - -def stop(tunnel, force=True): - try: - logging.info('Trying to stop resources...') - tunnel.stop(force=force) - except Exception as e: - logging.exception('Tunnel stop exception: %r', e) - raise - - -def show_threading_state_if_required(): - current_threads = list(threading.enumerate()) - if len(current_threads) > 1: - logging.warning('[1] THREAD INFO') - logging.info('Threads: %r', current_threads) - logging.info('Threads.daemon: %r', [x.daemon for x in current_threads]) - - if len(current_threads) > 1: - logging.warning('[2] STACK INFO') - code = ["\n\n*** STACKTRACE - START ***\n"] - for threadId, stack in sys._current_frames().items(): - code.append("\n# ThreadID: %s" % threadId) - for filename, lineno, name, line in traceback.extract_stack(stack): - code.append('File: "%s", line %d, in %s' % (filename, lineno, name)) - if line: - code.append(" %s" % (line.strip())) - code.append("\n*** STACKTRACE - END ***\n\n") - logging.info('\n'.join(code)) - - -if __name__ == '__main__': - logging.warning('RUN') - tunnel = create_tunnel() - start(tunnel) - res = run_db_queries(tunnel) - stop(tunnel) - wait_and_check_or_restart_if_required(tunnel) - show_threading_state_if_required() - - logging.info('RESULT POSTGRES: %r', res[0]) - logging.info('RESULT MYSQL: %r', res[1]) - logging.info('RESULT MONGO: %r', res[2]) - - assert res == (PG_EXPECT, MYSQL_EXPECT, MONGO_EXPECT) diff --git a/e2e_tests/run_docker_e2e_hangs_tests.py b/e2e_tests/run_docker_e2e_hangs_tests.py deleted file mode 100644 index 0ec7449e..00000000 --- a/e2e_tests/run_docker_e2e_hangs_tests.py +++ /dev/null @@ -1,13 +0,0 @@ -import logging -import sshtunnel -import os - - -if __name__ == '__main__': - path = os.path.join(os.path.dirname(__file__), 'run_docker_e2e_db_tests.py') - with open(path) as f: - exec(f.read()) - logging.warning('RUN') - tunnel = create_tunnel() - start(tunnel) - logging.warning('EOF') diff --git a/e2e_tests/ssh-server-config/ssh_host_rsa_key b/e2e_tests/ssh-server-config/ssh_host_rsa_key deleted file mode 100644 index 01f74b4d..00000000 --- a/e2e_tests/ssh-server-config/ssh_host_rsa_key +++ /dev/null @@ -1,38 +0,0 @@ ------BEGIN OPENSSH PRIVATE KEY----- -b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn -NhAAAAAwEAAQAAAYEAvIGU0pRpThhIcaSPrg2+v7cXl+QcG0icb45hfD44yrCoXkpJp7nh -Hv0ObZL2Y1cG7eeayYF4AqD3kwQ7W89GN6YO9b/mkJgawk0/YLUyojTS9dbcTbdkfPzyUa -vTMDjly+PIjfiWOEnUgPf1y3xONLkJU0ILyTmgTzSIMNdKngtdCGfytBCuNiPKU8hEdEVt -82ebqgtLoSYn9cUcVVz6LewzUh8+YtoPb8Z/BIVEzU37HiE9MOYIBXjo1AEJSnOCkjwlVl -PzLhcXKTPht0iwv/KnZNNg0LDmnU/z0n+nPq/EMflum8jRYbgp0C5hksPdc8e0eEKd9gak -t7B0ta3Mjt5b8HPQdBGZI/QFufEnSOxfJmoK4Bvjy/oUwE0hGU6po5g+4T2j6Bqqm2I+yV -EbkP/UiuD/kEiT0C3yCV547gIDjN2ME9tGJDkd023BFvqn3stFVVZ5WsisRKGc+lvTfqeA -JyKFaVt5a23y68ztjEMVrMLksRuEF8gG5kV7EGyjAAAFiCzGBRksxgUZAAAAB3NzaC1yc2 -EAAAGBALyBlNKUaU4YSHGkj64Nvr+3F5fkHBtInG+OYXw+OMqwqF5KSae54R79Dm2S9mNX -Bu3nmsmBeAKg95MEO1vPRjemDvW/5pCYGsJNP2C1MqI00vXW3E23ZHz88lGr0zA45cvjyI -34ljhJ1ID39ct8TjS5CVNCC8k5oE80iDDXSp4LXQhn8rQQrjYjylPIRHRFbfNnm6oLS6Em -J/XFHFVc+i3sM1IfPmLaD2/GfwSFRM1N+x4hPTDmCAV46NQBCUpzgpI8JVZT8y4XFykz4b -dIsL/yp2TTYNCw5p1P89J/pz6vxDH5bpvI0WG4KdAuYZLD3XPHtHhCnfYGpLewdLWtzI7e -W/Bz0HQRmSP0BbnxJ0jsXyZqCuAb48v6FMBNIRlOqaOYPuE9o+gaqptiPslRG5D/1Irg/5 -BIk9At8gleeO4CA4zdjBPbRiQ5HdNtwRb6p97LRVVWeVrIrEShnPpb036ngCcihWlbeWtt -8uvM7YxDFazC5LEbhBfIBuZFexBsowAAAAMBAAEAAAGAflHjdb2oV4HkQetBsSRa18QM1m -cxAoOE+SiTYRudGQ6KtSzY8MGZ/xca7QiXfXhbF1+llTTiQ/i0Dtu+H0blyfLIgZwIGIsl -G2GCf/7MoG//kmhaFuY3O56Rj3MyQVVPgHLy+VhE6hFniske+C4jhicc/aL7nOu15n3Qad -JLmV8KB9EIjevDoloXgk9ot/WyuXKLmMaa9rFIA+UDmJyGtfFbbsOrHbj8sS11/oSD14RT -LBygEb2EUI52j2LmY/LEvUL+59oCuJ6Y/h+pMdFeuHJzGjrVb573KnGwejzY24HHzzebrC -Q+9NyVCTyizPHNu9w52/GPEZQFQBi7o9cDMd3ITZEPIaIvDHsUwPXaHUBHy/XHQTs8pDqk -zCMcAs5zdzao2I0LQ+ZFYyvl1rue82ITjDISX1WK6nFYLBVXugi0rLGEdH6P+Psfl3uCIf -aW7c12/BpZz2Pql5AuO1wsu4rmz2th68vaC/0IDqWekIbW9qihFbqnhfAxRsIURjpBAAAA -wDhIQPsj9T9Vud3Z/TZjiAKCPbg3zi082u1GMMxXnNQtKO3J35wU7VUcAxAzosWr+emMqS -U0qW+a5RXr3sqUOqH85b5+Xw0yv2sTr2pL0ALFW7Tq1mesCc3K0So3Yo30pWRIOxYM9ihm -E4ci/3mN5kcKWwvLLomFPRU9u0XtIGKnF/cNByTuz9fceR6Pi6mQXZawv+OOMiBeu0gbyp -F1uVe8PCshzCrWTE3UjRpQxy9gizvSbGZyGQi1Lm42JXKG3wAAAMEA4r4CLM1xsyxBBMld -rxiTqy6bfrZjKkT5MPjBjp+57i5kW9NVqGCnIy/m98pLTuKjTCDmUuWQXS+oqhHw5vq/wj -RvQYqkJDz1UGmC1lD2qyqERjOiWa8/iy4dXSLeHCT70+/xR2dBb0z8cT++yZEqLdEZSnHG -yRaZMHot1OohVDqJS8nEbxOzgPGdopRMiX6ws/p5/k9YAGkHx0hszA8cn/Tk2/mdS5lugw -Y7mdXzfcKvxkgoFrG7XowqRVrozcvDAAAAwQDU1ITasquNLaQhKNqiHx/N7bvKVO33icAx -NdShqJEWx/g9idvQ25sA1Ubc1a+Ot5Lgfrs2OBKe+LgSmPAZOjv4ShqBHtsSh3am8/K1xR -gQKgojLL4FhtgxtwoZrVvovZHGV3g2A28BRGbKIGVGPsOszJALU7jlLlcTHlB7SCQBI8FQ -vTi2UEsfTmA22NnuVPITeqbmAQQXkSZcZbpbvdc0vQzp/3iOb/OCrIMET3HqVEMyQVsVs6 -xa9026AMTGLaEAAAATcm9vdEBvcGVuc3NoLXNlcnZlcg== ------END OPENSSH PRIVATE KEY----- diff --git a/e2e_tests/ssh-server-config/ssh_host_rsa_key.pub b/e2e_tests/ssh-server-config/ssh_host_rsa_key.pub deleted file mode 100644 index 3288d9e1..00000000 --- a/e2e_tests/ssh-server-config/ssh_host_rsa_key.pub +++ /dev/null @@ -1 +0,0 @@ -ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC8gZTSlGlOGEhxpI+uDb6/txeX5BwbSJxvjmF8PjjKsKheSkmnueEe/Q5tkvZjVwbt55rJgXgCoPeTBDtbz0Y3pg71v+aQmBrCTT9gtTKiNNL11txNt2R8/PJRq9MwOOXL48iN+JY4SdSA9/XLfE40uQlTQgvJOaBPNIgw10qeC10IZ/K0EK42I8pTyER0RW3zZ5uqC0uhJif1xRxVXPot7DNSHz5i2g9vxn8EhUTNTfseIT0w5ggFeOjUAQlKc4KSPCVWU/MuFxcpM+G3SLC/8qdk02DQsOadT/PSf6c+r8Qx+W6byNFhuCnQLmGSw91zx7R4Qp32BqS3sHS1rcyO3lvwc9B0EZkj9AW58SdI7F8magrgG+PL+hTATSEZTqmjmD7hPaPoGqqbYj7JURuQ/9SK4P+QSJPQLfIJXnjuAgOM3YwT20YkOR3TbcEW+qfey0VVVnlayKxEoZz6W9N+p4AnIoVpW3lrbfLrzO2MQxWswuSxG4QXyAbmRXsQbKM= root@openssh-server diff --git a/e2e_tests/ssh-server-config/sshd_config b/e2e_tests/ssh-server-config/sshd_config deleted file mode 100644 index 6bbe5a14..00000000 --- a/e2e_tests/ssh-server-config/sshd_config +++ /dev/null @@ -1,119 +0,0 @@ -# $OpenBSD: sshd_config,v 1.103 2018/04/09 20:41:22 tj Exp $ - -# This is the sshd server system-wide configuration file. See -# sshd_config(5) for more information. - -# This sshd was compiled with PATH=/bin:/usr/bin:/sbin:/usr/sbin - -# The strategy used for options in the default sshd_config shipped with -# OpenSSH is to specify options with their default value where -# possible, but leave them commented. Uncommented options override the -# default value. - -Port 2222 -#AddressFamily any -#ListenAddress 0.0.0.0 -#ListenAddress :: - -#HostKey /etc/ssh/ssh_host_rsa_key -#HostKey /etc/ssh/ssh_host_ecdsa_key -#HostKey /etc/ssh/ssh_host_ed25519_key - -# Ciphers and keying -#RekeyLimit default none - -# Logging -#SyslogFacility AUTH -LogLevel DEBUG - -# Authentication: - -#LoginGraceTime 2m -#PermitRootLogin prohibit-password -#StrictModes yes -#MaxAuthTries 6 -#MaxSessions 10 - -#PubkeyAuthentication yes - -# The default is to check both .ssh/authorized_keys and .ssh/authorized_keys2 -# but this is overridden so installations will only check .ssh/authorized_keys -AuthorizedKeysFile .ssh/authorized_keys - -#AuthorizedPrincipalsFile none - -#AuthorizedKeysCommand none -#AuthorizedKeysCommandUser nobody - -# For this to work you will also need host keys in /etc/ssh/ssh_known_hosts -#HostbasedAuthentication no -# Change to yes if you don't trust ~/.ssh/known_hosts for -# HostbasedAuthentication -#IgnoreUserKnownHosts no -# Don't read the user's ~/.rhosts and ~/.shosts files -#IgnoreRhosts yes - -# To disable tunneled clear text passwords, change to no here! -PasswordAuthentication no -#PermitEmptyPasswords no - -# Change to no to disable s/key passwords -#ChallengeResponseAuthentication yes - -# Kerberos options -#KerberosAuthentication no -#KerberosOrLocalPasswd yes -#KerberosTicketCleanup yes -#KerberosGetAFSToken no - -# GSSAPI options -#GSSAPIAuthentication no -#GSSAPICleanupCredentials yes - -# Set this to 'yes' to enable PAM authentication, account processing, -# and session processing. If this is enabled, PAM authentication will -# be allowed through the ChallengeResponseAuthentication and -# PasswordAuthentication. Depending on your PAM configuration, -# PAM authentication via ChallengeResponseAuthentication may bypass -# the setting of "PermitRootLogin without-password". -# If you just want the PAM account and session checks to run without -# PAM authentication, then enable this but set PasswordAuthentication -# and ChallengeResponseAuthentication to 'no'. -#UsePAM no - -#AllowAgentForwarding yes -# Feel free to re-enable these if your use case requires them. -AllowTcpForwarding yes -GatewayPorts no -X11Forwarding no -#X11DisplayOffset 10 -#X11UseLocalhost yes -#PermitTTY yes -PrintMotd no -#PrintLastLog yes -#TCPKeepAlive yes -#PermitUserEnvironment no -#Compression delayed -#ClientAliveInterval 0 -#ClientAliveCountMax 3 -#UseDNS no -PidFile /config/sshd.pid -#MaxStartups 10:30:100 -#PermitTunnel no -#ChrootDirectory none -#VersionAddendum none - -# no default banner path -Banner none - -# override default of no subsystems -Subsystem sftp /usr/lib/ssh/sftp-server -u 022 - -# Example of overriding settings on a per-user basis -#Match User anoncvs -# X11Forwarding no -# AllowTcpForwarding no -# PermitTTY no -# ForceCommand cvs server - -# !! \ No newline at end of file diff --git a/e2e_tests/test_db_tunnel.py b/e2e_tests/test_db_tunnel.py new file mode 100644 index 00000000..25a9f7b4 --- /dev/null +++ b/e2e_tests/test_db_tunnel.py @@ -0,0 +1,154 @@ +"""E2E tests: verify SSH tunnels to real database containers.""" + +import pytest +from sshtunnel import SSHTunnelForwarder + +from .conftest import PG_USER, PG_PASSWORD, PG_DB, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DB, MONGO_USER, MONGO_PASSWORD + +pytestmark = pytest.mark.timeout(120) + + +def _make_tunnel(infra, remote_bind_addresses): + """Create (but don't start) an SSHTunnelForwarder from the fixture data.""" + return SSHTunnelForwarder( + (infra["ssh_host"], infra["ssh_port"]), + ssh_username=infra["ssh_username"], + ssh_pkey=infra["ssh_pkey"], + remote_bind_addresses=remote_bind_addresses, + local_bind_addresses=[("127.0.0.1", 0) for _ in remote_bind_addresses], + ) + + +class TestPostgresTunnel: + def test_query_via_tunnel(self, e2e_infrastructure): + import psycopg2 + + infra = e2e_infrastructure + pg = infra["pg"] + + with _make_tunnel(infra, [(pg["host"], pg["port"])]) as tunnel: + conn = psycopg2.connect( + host="127.0.0.1", + port=tunnel.local_bind_port, + database=PG_DB, + user=PG_USER, + password=PG_PASSWORD, + connect_timeout=10, + ) + cur = conn.cursor() + cur.execute("SELECT version()") + result = cur.fetchone()[0] + conn.close() + + assert "PostgreSQL" in result + + +class TestMySQLTunnel: + def test_query_via_tunnel(self, e2e_infrastructure): + import pymysql + + infra = e2e_infrastructure + mysql = infra["mysql"] + + with _make_tunnel(infra, [(mysql["host"], mysql["port"])]) as tunnel: + conn = pymysql.connect( + host="127.0.0.1", + port=tunnel.local_bind_port, + user=MYSQL_USER, + password=MYSQL_PASSWORD, + database=MYSQL_DB, + connect_timeout=10, + read_timeout=10, + ) + cursor = conn.cursor() + cursor.execute("SELECT version()") + result = cursor.fetchone()[0] + conn.close() + + assert result # non-empty version string + + +class TestMongoTunnel: + def test_query_via_tunnel(self, e2e_infrastructure): + import pymongo + + infra = e2e_infrastructure + mongo = infra["mongo"] + + with _make_tunnel(infra, [(mongo["host"], mongo["port"])]) as tunnel: + client = pymongo.MongoClient( + "127.0.0.1", + tunnel.local_bind_port, + username=MONGO_USER, + password=MONGO_PASSWORD, + serverSelectionTimeoutMS=10000, + connectTimeoutMS=10000, + socketTimeoutMS=10000, + ) + info = client.server_info() + client.close() + + assert "version" in info + + +class TestMultiTunnel: + def test_all_databases_via_single_tunnel(self, e2e_infrastructure): + """Open a single SSH tunnel forwarding to all three databases at once.""" + import psycopg2 + import pymysql + import pymongo + + infra = e2e_infrastructure + pg = infra["pg"] + mysql = infra["mysql"] + mongo = infra["mongo"] + + remote_binds = [ + (pg["host"], pg["port"]), + (mysql["host"], mysql["port"]), + (mongo["host"], mongo["port"]), + ] + + with _make_tunnel(infra, remote_binds) as tunnel: + pg_port, mysql_port, mongo_port = tunnel.local_bind_ports + + # Postgres + pg_conn = psycopg2.connect( + host="127.0.0.1", + port=pg_port, + database=PG_DB, + user=PG_USER, + password=PG_PASSWORD, + connect_timeout=10, + ) + pg_cur = pg_conn.cursor() + pg_cur.execute("SELECT 1") + assert pg_cur.fetchone() == (1,) + pg_conn.close() + + # MySQL + mysql_conn = pymysql.connect( + host="127.0.0.1", + port=mysql_port, + user=MYSQL_USER, + password=MYSQL_PASSWORD, + database=MYSQL_DB, + connect_timeout=10, + ) + mysql_cur = mysql_conn.cursor() + mysql_cur.execute("SELECT 1") + assert mysql_cur.fetchone() == (1,) + mysql_conn.close() + + # MongoDB + mongo_client = pymongo.MongoClient( + "127.0.0.1", + mongo_port, + username=MONGO_USER, + password=MONGO_PASSWORD, + serverSelectionTimeoutMS=10000, + connectTimeoutMS=10000, + socketTimeoutMS=10000, + ) + assert mongo_client.server_info()["ok"] == 1.0 + mongo_client.close() diff --git a/e2e_tests/test_hang.py b/e2e_tests/test_hang.py new file mode 100644 index 00000000..351f1bb1 --- /dev/null +++ b/e2e_tests/test_hang.py @@ -0,0 +1,50 @@ +"""E2E test: verify an SSH tunnel starts, stays alive briefly, and stops cleanly.""" + +import time + +import pytest +from sshtunnel import SSHTunnelForwarder + +pytestmark = pytest.mark.timeout(60) + + +def test_tunnel_does_not_hang_on_start_stop(e2e_infrastructure): + """Start a tunnel, verify it's alive, then stop it without hanging.""" + infra = e2e_infrastructure + pg = infra["pg"] + + tunnel = SSHTunnelForwarder( + (infra["ssh_host"], infra["ssh_port"]), + ssh_username=infra["ssh_username"], + ssh_pkey=infra["ssh_pkey"], + remote_bind_addresses=[(pg["host"], pg["port"])], + ) + + tunnel.start() + assert tunnel.is_alive + assert tunnel.is_active + + time.sleep(2) + + assert tunnel.is_alive + assert tunnel.is_active + + tunnel.stop() + assert not tunnel.is_alive + + +def test_tunnel_context_manager_does_not_hang(e2e_infrastructure): + """Verify the context manager enters and exits without hanging.""" + infra = e2e_infrastructure + pg = infra["pg"] + + with SSHTunnelForwarder( + (infra["ssh_host"], infra["ssh_port"]), + ssh_username=infra["ssh_username"], + ssh_pkey=infra["ssh_pkey"], + remote_bind_addresses=[(pg["host"], pg["port"])], + ) as tunnel: + assert tunnel.is_alive + time.sleep(1) + + assert not tunnel.is_alive diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 00000000..e39d9041 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,38 @@ +import nox + +nox.options.default_venv_backend = "uv" +nox.options.reuse_venv = "yes" + +PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13"] +PARAMIKO_VERSIONS = ["paramiko>=3.4,<4", "paramiko>=4,<5", "paramiko>=5,<6"] + + +@nox.session(python=PYTHON_VERSIONS) +@nox.parametrize("paramiko", PARAMIKO_VERSIONS) +def tests(session: nox.Session, paramiko: str) -> None: + """Run the test suite against a specific paramiko version.""" + session.install("-e", ".[test]") + session.install(paramiko) + session.run( + "pytest", + "tests/", + "-n4", + "--cov=sshtunnel", + "--cov-report=term", + *session.posargs, + ) + + +@nox.session(python=["3.13"]) +def e2e(session: nox.Session) -> None: + """Run e2e tests with testcontainers (requires Docker).""" + session.install("-e", ".[test,e2e]") + session.run("pytest", "e2e_tests/", "-v", *session.posargs) + + +@nox.session +def lint(session: nox.Session) -> None: + """Run ruff linter and formatter checks.""" + session.install("ruff") + session.run("ruff", "check", ".") + session.run("ruff", "format", "--check", ".") diff --git a/pyproject.toml b/pyproject.toml index b0471b7f..23fc1faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,102 @@ [build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta:__legacy__" \ No newline at end of file +requires = ["hatchling", "uv-dynamic-versioning"] +build-backend = "hatchling.build" + +[project] +name = "deepnote-sshtunnel" +dynamic = ["version"] +description = "Pure python SSH tunnels (Deepnote fork)" +readme = "README.rst" +requires-python = ">=3.10" +license = "MIT" +authors = [ + {name = "Pahaz White", email = "pahaz.white@gmail.com"}, +] +maintainers = [ + {name = "Deepnote", email = "product-engineers@deepnote.com"}, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Build Tools", +] +keywords = ["ssh", "tunnel", "paramiko", "proxy", "tcp-forward"] +dependencies = [ + "paramiko>=3.4,<6", +] + +[project.urls] +Homepage = "https://github.com/deepnote/sshtunnel" +Issues = "https://github.com/deepnote/sshtunnel/issues" +Changelog = "https://github.com/deepnote/sshtunnel/blob/main/changelog.rst" + +[project.scripts] +sshtunnel = "sshtunnel:_cli_main" + +[project.optional-dependencies] +test = [ + "pytest", + "pytest-cov", + "pytest-xdist", +] +e2e = [ + "testcontainers", + "psycopg2-binary", + "pymysql", + "pymongo", + "pytest", + "pytest-timeout", +] + +[dependency-groups] +dev = [ + "nox[uv]", + "ruff", + "pytest", + "pytest-cov", + "pytest-xdist", +] + +[tool.hatch.version] +source = "uv-dynamic-versioning" + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.1" +pattern = "default-unprefixed" +vcs = "git" +style = "pep440" +metadata = false + +[tool.hatch.build.targets.wheel] +include = ["sshtunnel.py"] + +[tool.hatch.build.targets.sdist] +include = [ + "sshtunnel.py", + "README.rst", + "changelog.rst", + "docs.rst", + "LICENSE", +] + +[tool.ruff] +target-version = "py310" +line-length = 120 +exclude = ["docs"] + +[tool.ruff.lint] +select = ["E", "F", "W"] +ignore = [ + "E501", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "--showlocals --durations=10" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 94f65bb8..00000000 --- a/setup.cfg +++ /dev/null @@ -1,15 +0,0 @@ -[bdist_wheel] -# This flag says that the code is written to work on both Python 2 and Python -# 3. If at all possible, it is good practice to do this. If you cannot, you -# will need to generate wheels for each Python version that you support. -universal=1 - -[check-manifest] -ignore = - .travis.yml - circle.yml - tox.ini - -[build_sphinx] -source-dir = docs/ -build-dir = docs/_build diff --git a/setup.py b/setup.py deleted file mode 100644 index ccaaab8c..00000000 --- a/setup.py +++ /dev/null @@ -1,137 +0,0 @@ -"""A setuptools based setup module. - -See: -https://packaging.python.org/en/latest/distributing.html -https://github.com/pypa/sampleproject -""" - -import re -from os import path -from codecs import open # To use a consistent encoding -from setuptools import setup # Always prefer setuptools over distutils - -here = path.abspath(path.dirname(__file__)) -name = 'sshtunnel' -description = 'Pure python SSH tunnels' -url = 'https://github.com/pahaz/sshtunnel' -ppa = 'https://pypi.python.org/packages/source/s/{0}/{0}-'.format(name) - -# Get the long description from the README file -with open(path.join(here, 'README.rst'), encoding='utf-8') as f: - long_description = f.read() -with open(path.join(here, 'docs.rst'), encoding='utf-8') as f: - documentation = f.read() -with open(path.join(here, 'changelog.rst'), encoding='utf-8') as f: - changelog = f.read() - -with open(path.join(here, name + '.py'), encoding='utf-8') as f: - data = f.read() - version = eval(re.search("__version__[ ]*=[ ]*([^\r\n]+)", data).group(1)) - - -setup( - name=name, - - # Versions should comply with PEP440. For a discussion on single-sourcing - # the version across setup.py and the project code, see - # https://packaging.python.org/en/latest/single_source_version.html - version=version, - - description=description, - long_description='\n'.join((long_description, documentation, changelog)), - long_description_content_type='text/x-rst', - - # The project's main homepage. - url=url, - download_url=ppa + version + '.zip', # noqa - - # Author details - author='Pahaz White', - author_email='pahaz.white@gmail.com', - - # Choose your license - license='MIT', - - # See https://pypi.python.org/pypi?%3Aaction=list_classifiers - classifiers=[ - # How mature is this project? Common values are - # 3 - Alpha - # 4 - Beta - # 5 - Production/Stable - 'Development Status :: 3 - Alpha', - - # Indicate who your project is intended for - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Build Tools', - - # Pick your license as you wish (should match "license" above) - 'License :: OSI Approved :: MIT License', - - # Specify the Python versions you support here. In particular, ensure - # that you indicate whether you support Python 2, Python 3 or both. - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - ], - - platforms=['unix', 'macos', 'windows'], - - # What does your project relate to? - keywords='ssh tunnel paramiko proxy tcp-forward', - - # You can just specify the packages manually here if your project is - # simple. Or you can use find_packages(). - # packages=find_packages(exclude=['contrib', 'docs', 'tests']), - - # Alternatively, if you want to distribute just a my_module.py, uncomment - # this: - py_modules=["sshtunnel"], - - # List run-time dependencies here. These will be installed by pip when - # your project is installed. For an analysis of "install_requires" vs pip's - # requirements files see: - # https://packaging.python.org/en/latest/requirements.html - install_requires=[ - 'paramiko>=2.7.2', - ], - - # List additional groups of dependencies here (e.g. development - # dependencies). You can install these using the following syntax, - # for example: - # $ pip install -e .[dev,test] - tests_require=[ - 'tox>=1.8.1', - ], - extras_require={ - 'dev': ['check-manifest'], - 'test': [ - 'tox>=1.8.1', - ], - 'build_sphinx': [ - 'sphinx', - 'sphinxcontrib-napoleon', - ], - }, - - # If there are data files included in your packages that need to be - # installed, specify them here. If using Python 2.6 or less, then these - # have to be included in MANIFEST.in as well. - package_data={ - 'tests': ['testrsa.key'], - }, - - # To provide executable scripts, use entry points in preference to the - # "scripts" keyword. Entry points provide cross-platform support and allow - # pip to create the appropriate form of executable for the target platform. - entry_points={ - 'console_scripts': [ - 'sshtunnel=sshtunnel:_cli_main', - ] - }, - -) diff --git a/sshtunnel.py b/sshtunnel.py index a7db0c44..5407c933 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ *sshtunnel* - Initiate SSH tunnels via a remote gateway. @@ -12,10 +11,12 @@ """ import os +import queue import random import string import sys import socket +import socketserver import getpass import logging import argparse @@ -23,23 +24,18 @@ import threading from select import select from binascii import hexlify +from importlib.metadata import version as _get_version import paramiko -if sys.version_info[0] < 3: # pragma: no cover - import Queue as queue - import SocketServer as socketserver - string_types = basestring, # noqa - input_ = raw_input # noqa -else: # pragma: no cover - import queue - import socketserver - string_types = str - input_ = input +string_types = str +input_ = input - -__version__ = '0.4.0' -__author__ = 'pahaz' +try: + __version__ = _get_version("deepnote-sshtunnel") +except Exception: + __version__ = "0.0.1" +__author__ = "pahaz" #: Timeout (seconds) for transport socket (``socket.settimeout``) @@ -50,24 +46,23 @@ _DAEMON = True #: Use daemon threads in connections _CONNECTION_COUNTER = 1 _DEPRECATIONS = { - 'ssh_address': 'ssh_address_or_host', - 'ssh_host': 'ssh_address_or_host', - 'ssh_private_key': 'ssh_pkey', - 'raise_exception_if_any_forwarder_have_a_problem': 'mute_exceptions' + "ssh_address": "ssh_address_or_host", + "ssh_host": "ssh_address_or_host", + "ssh_private_key": "ssh_pkey", + "raise_exception_if_any_forwarder_have_a_problem": "mute_exceptions", } # logging DEFAULT_LOGLEVEL = logging.ERROR #: default level if no logger passed (ERROR) TRACE_LEVEL = 1 -logging.addLevelName(TRACE_LEVEL, 'TRACE') -DEFAULT_SSH_DIRECTORY = '~/.ssh' +logging.addLevelName(TRACE_LEVEL, "TRACE") +DEFAULT_SSH_DIRECTORY = "~/.ssh" -_StreamServer = socketserver.UnixStreamServer if os.name == 'posix' \ - else socketserver.TCPServer +_StreamServer = socketserver.UnixStreamServer if os.name == "posix" else socketserver.TCPServer #: Path of optional ssh configuration file -DEFAULT_SSH_DIRECTORY = '~/.ssh' -SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, 'config') +DEFAULT_SSH_DIRECTORY = "~/.ssh" +SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, "config") ######################## # # @@ -77,14 +72,12 @@ def check_host(host): - assert isinstance(host, string_types), 'IP is not a string ({0})'.format( - type(host).__name__ - ) + assert isinstance(host, string_types), "IP is not a string ({0})".format(type(host).__name__) def check_port(port): - assert isinstance(port, int), 'PORT is not a number' - assert port >= 0, 'PORT < 0 ({0})'.format(port) + assert isinstance(port, int), "PORT is not a number" + assert port >= 0, "PORT < 0 ({0})".format(port) def check_address(address): @@ -110,15 +103,12 @@ def check_address(address): check_host(address[0]) check_port(address[1]) elif isinstance(address, string_types): - if os.name != 'posix': - raise ValueError('Platform does not support UNIX domain sockets') - if not (os.path.exists(address) or - os.access(os.path.dirname(address), os.W_OK)): - raise ValueError('ADDRESS not a valid socket domain socket ({0})' - .format(address)) + if os.name != "posix": + raise ValueError("Platform does not support UNIX domain sockets") + if not (os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK)): + raise ValueError("ADDRESS not a valid socket domain socket ({0})".format(address)) else: - raise ValueError('ADDRESS is not a tuple, string, or character buffer ' - '({0})'.format(type(address).__name__)) + raise ValueError("ADDRESS is not a tuple, string, or character buffer ({0})".format(type(address).__name__)) def check_addresses(address_list, is_remote=False): @@ -148,18 +138,14 @@ def check_addresses(address_list, is_remote=False): >>> check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)]) """ assert all(isinstance(x, (tuple, string_types)) for x in address_list) - if (is_remote and any(isinstance(x, string_types) for x in address_list)): - raise AssertionError('UNIX domain sockets not allowed for remote' - 'addresses') + if is_remote and any(isinstance(x, string_types) for x in address_list): + raise AssertionError("UNIX domain sockets not allowed for remoteaddresses") for address in address_list: check_address(address) -def create_logger(logger=None, - loglevel=None, - capture_warnings=True, - add_paramiko_handler=True): +def create_logger(logger=None, loglevel=None, capture_warnings=True, add_paramiko_handler=True): """ Attach or create a new logger and add a console handler if not present @@ -181,8 +167,6 @@ def create_logger(logger=None, Default: True - .. note:: ignored in python 2.6 - add_paramiko_handler (boolean): Whether or not add a console handler for ``paramiko.transport``'s logger if no handler present @@ -191,15 +175,11 @@ def create_logger(logger=None, Return: :class:`logging.Logger` """ - logger = logger or logging.getLogger( - 'sshtunnel.SSHTunnelForwarder' - ) + logger = logger or logging.getLogger("sshtunnel.SSHTunnelForwarder") if not any(isinstance(x, logging.Handler) for x in logger.handlers): logger.setLevel(loglevel or DEFAULT_LOGLEVEL) console_handler = logging.StreamHandler() - _add_handler(logger, - handler=console_handler, - loglevel=loglevel or DEFAULT_LOGLEVEL) + _add_handler(logger, handler=console_handler, loglevel=loglevel or DEFAULT_LOGLEVEL) if loglevel: # override if loglevel was set logger.setLevel(loglevel) for handler in logger.handlers: @@ -208,9 +188,9 @@ def create_logger(logger=None, if add_paramiko_handler: _check_paramiko_handlers(logger=logger) - if capture_warnings and sys.version_info >= (2, 7): + if capture_warnings: logging.captureWarnings(True) - pywarnings = logging.getLogger('py.warnings') + pywarnings = logging.getLogger("py.warnings") pywarnings.handlers.extend(logger.handlers) return logger @@ -221,13 +201,10 @@ def _add_handler(logger, handler=None, loglevel=None): """ handler.setLevel(loglevel or DEFAULT_LOGLEVEL) if handler.level <= logging.DEBUG: - _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ - '%(lineno)04d@%(module)-10.9s| %(message)s' + _fmt = "%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s" handler.setFormatter(logging.Formatter(_fmt)) else: - handler.setFormatter(logging.Formatter( - '%(asctime)s| %(levelname)-8s| %(message)s' - )) + handler.setFormatter(logging.Formatter("%(asctime)s| %(levelname)-8s| %(message)s")) logger.addHandler(handler) @@ -235,34 +212,33 @@ def _check_paramiko_handlers(logger=None): """ Add a console handler for paramiko.transport's logger if not present """ - paramiko_logger = logging.getLogger('paramiko.transport') + paramiko_logger = logging.getLogger("paramiko.transport") if not paramiko_logger.handlers: if logger: paramiko_logger.handlers = logger.handlers else: console_handler = logging.StreamHandler() console_handler.setFormatter( - logging.Formatter('%(asctime)s | %(levelname)-8s| PARAMIKO: ' - '%(lineno)03d@%(module)-10s| %(message)s') + logging.Formatter("%(asctime)s | %(levelname)-8s| PARAMIKO: %(lineno)03d@%(module)-10s| %(message)s") ) paramiko_logger.addHandler(console_handler) def address_to_str(address): if isinstance(address, tuple): - return '{0[0]}:{0[1]}'.format(address) + return "{0[0]}:{0[1]}".format(address) return str(address) def _remove_none_values(dictionary): - """ Remove dictionary keys whose value is None """ - return list(map(dictionary.pop, - [i for i in dictionary if dictionary[i] is None])) + """Remove dictionary keys whose value is None""" + return list(map(dictionary.pop, [i for i in dictionary if dictionary[i] is None])) def generate_random_string(length): letters = string.ascii_letters + string.digits - return ''.join(random.choice(letters) for _ in range(length)) + return "".join(random.choice(letters) for _ in range(length)) + ######################## # # @@ -272,17 +248,18 @@ def generate_random_string(length): class BaseSSHTunnelForwarderError(Exception): - """ Exception raised by :class:`SSHTunnelForwarder` errors """ + """Exception raised by :class:`SSHTunnelForwarder` errors""" def __init__(self, *args, **kwargs): - self.value = kwargs.pop('value', args[0] if args else '') + self.value = kwargs.pop("value", args[0] if args else "") def __str__(self): return self.value class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): - """ Exception for Tunnel forwarder errors """ + """Exception for Tunnel forwarder errors""" + pass @@ -294,7 +271,8 @@ class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): class _ForwardHandler(socketserver.BaseRequestHandler): - """ Base handler for tunnel connections """ + """Base handler for tunnel connections""" + remote_address = None ssh_transport = None logger = None @@ -306,59 +284,42 @@ def _redirect(self, chan): if self.request in rqst: data = self.request.recv(16384) if not data: - self.logger.log( - TRACE_LEVEL, - '>>> OUT {0} recv empty data >>>'.format(self.info) - ) + self.logger.log(TRACE_LEVEL, ">>> OUT {0} recv empty data >>>".format(self.info)) break if self.logger.isEnabledFor(TRACE_LEVEL): self.logger.log( TRACE_LEVEL, - '>>> OUT {0} send to {1}: {2} >>>'.format( - self.info, - self.remote_address, - hexlify(data) - ) + ">>> OUT {0} send to {1}: {2} >>>".format(self.info, self.remote_address, hexlify(data)), ) chan.sendall(data) if chan in rqst: # else if not chan.recv_ready(): - self.logger.log( - TRACE_LEVEL, - '<<< IN {0} recv is not ready <<<'.format(self.info) - ) + self.logger.log(TRACE_LEVEL, "<<< IN {0} recv is not ready <<<".format(self.info)) break data = chan.recv(16384) if self.logger.isEnabledFor(TRACE_LEVEL): hex_data = hexlify(data) - self.logger.log( - TRACE_LEVEL, - '<<< IN {0} recv: {1} <<<'.format(self.info, hex_data) - ) + self.logger.log(TRACE_LEVEL, "<<< IN {0} recv: {1} <<<".format(self.info, hex_data)) self.request.sendall(data) def handle(self): uid = generate_random_string(5) - self.info = '#{0} <-- {1}'.format(uid, self.client_address or - self.server.local_address) + self.info = "#{0} <-- {1}".format(uid, self.client_address or self.server.local_address) src_address = self.request.getpeername() if not isinstance(src_address, tuple): - src_address = ('dummy', 12345) + src_address = ("dummy", 12345) try: chan = self.ssh_transport.open_channel( - kind='direct-tcpip', - dest_addr=self.remote_address, - src_addr=src_address, - timeout=TUNNEL_TIMEOUT + kind="direct-tcpip", dest_addr=self.remote_address, src_addr=src_address, timeout=TUNNEL_TIMEOUT ) except Exception as e: # pragma: no cover - msg_tupe = 'ssh ' if isinstance(e, paramiko.SSHException) else '' - exc_msg = 'open new channel {0}error: {1}'.format(msg_tupe, e) - log_msg = '{0} {1}'.format(self.info, exc_msg) + msg_tupe = "ssh " if isinstance(e, paramiko.SSHException) else "" + exc_msg = "open new channel {0}error: {1}".format(msg_tupe, e) + log_msg = "{0} {1}".format(self.info, exc_msg) self.logger.log(TRACE_LEVEL, log_msg) raise HandlerSSHTunnelForwarderError(exc_msg) - self.logger.log(TRACE_LEVEL, '{0} connected'.format(self.info)) + self.logger.log(TRACE_LEVEL, "{0} connected".format(self.info)) try: self._redirect(chan) except socket.error: @@ -366,25 +327,24 @@ def handle(self): # exception. It was seen that a 3way FIN is processed later on, so # no need to make an ordered close of the connection here or raise # the exception beyond this point... - self.logger.log(TRACE_LEVEL, '{0} sending RST'.format(self.info)) + self.logger.log(TRACE_LEVEL, "{0} sending RST".format(self.info)) except Exception as e: - self.logger.log(TRACE_LEVEL, - '{0} error: {1}'.format(self.info, repr(e))) + self.logger.log(TRACE_LEVEL, "{0} error: {1}".format(self.info, repr(e))) finally: chan.close() self.request.close() - self.logger.log(TRACE_LEVEL, - '{0} connection closed.'.format(self.info)) + self.logger.log(TRACE_LEVEL, "{0} connection closed.".format(self.info)) class _ForwardServer(socketserver.TCPServer): # Not Threading """ Non-threading version of the forward server """ + allow_reuse_address = True # faster rebinding def __init__(self, *args, **kwargs): - logger = kwargs.pop('logger', None) + logger = kwargs.pop("logger", None) self.logger = logger or create_logger() self.tunnel_ok = queue.Queue(1) socketserver.TCPServer.__init__(self, *args, **kwargs) @@ -393,16 +353,18 @@ def handle_error(self, request, client_address): (exc_class, exc, tb) = sys.exc_info() local_side = request.getsockname() remote_side = self.remote_address - self.logger.error('Could not establish connection from local {0} ' - 'to remote {1} side of the tunnel: {2}' - .format(local_side, remote_side, exc)) + self.logger.error( + "Could not establish connection from local {0} to remote {1} side of the tunnel: {2}".format( + local_side, remote_side, exc + ) + ) try: self.tunnel_ok.put(False, block=False, timeout=0.1) except queue.Full: # wait untill tunnel_ok.get is called pass except exc: - self.logger.error('unexpected internal error: {0}'.format(exc)) + self.logger.error("unexpected internal error: {0}".format(exc)) @property def local_address(self): @@ -433,6 +395,7 @@ class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer): """ Allow concurrent connections to each tunnel """ + # If True, cleanly stop threads created by ThreadingMixIn when quitting # This value is overrides by SSHTunnelForwarder.daemon_forward_servers daemon_threads = _DAEMON @@ -444,7 +407,7 @@ class _StreamForwardServer(_StreamServer): """ def __init__(self, *args, **kwargs): - logger = kwargs.pop('logger', None) + logger = kwargs.pop("logger", None) self.logger = logger or create_logger() self.tunnel_ok = queue.Queue(1) _StreamServer.__init__(self, *args, **kwargs) @@ -474,11 +437,11 @@ def remote_port(self): return self.RequestHandlerClass.remote_address[1] -class _ThreadingStreamForwardServer(socketserver.ThreadingMixIn, - _StreamForwardServer): +class _ThreadingStreamForwardServer(socketserver.ThreadingMixIn, _StreamForwardServer): """ Allow concurrent connections to each tunnel """ + # If True, cleanly stop threads created by ThreadingMixIn when quitting # This value is overrides by SSHTunnelForwarder.daemon_forward_servers daemon_threads = _DAEMON @@ -741,6 +704,7 @@ class SSHTunnelForwarder(object): .. versionadded:: 0.1.0 """ + skip_tunnel_checkup = True # This option affects the `ForwardServer` and all his threads daemon_forward_servers = _DAEMON #: flag tunnel threads in daemon mode @@ -765,10 +729,12 @@ def local_is_up(self, target): try: check_address(target) except ValueError: - self.logger.warning('Target must be a tuple (IP, port), where IP ' - 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000). Alternatively ' - 'target can be a valid UNIX domain socket.') + self.logger.warning( + "Target must be a tuple (IP, port), where IP " + 'is a string (i.e. "192.168.0.1") and port is ' + "an integer (i.e. 40000). Alternatively " + "target can be a valid UNIX domain socket." + ) return False self.check_tunnels() @@ -789,11 +755,11 @@ def check_tunnels(self): self.skip_tunnel_checkup = skip_tunnel_checkup # roll it back def _check_tunnel(self, _srv): - """ Check if tunnel is already established """ + """Check if tunnel is already established""" if self.skip_tunnel_checkup: self.tunnel_is_up[_srv.local_address] = True return - self.logger.info('Checking tunnel to: {0}'.format(_srv.remote_address)) + self.logger.info("Checking tunnel to: {0}".format(_srv.remote_address)) if isinstance(_srv.local_address, string_types): # UNIX stream s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: @@ -801,25 +767,16 @@ def _check_tunnel(self, _srv): s.settimeout(TUNNEL_TIMEOUT) try: # Windows raises WinError 10049 if trying to connect to 0.0.0.0 - connect_to = ('127.0.0.1', _srv.local_port) \ - if _srv.local_host == '0.0.0.0' else _srv.local_address + connect_to = ("127.0.0.1", _srv.local_port) if _srv.local_host == "0.0.0.0" else _srv.local_address s.connect(connect_to) - self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( - timeout=TUNNEL_TIMEOUT * 1.1 - ) - self.logger.debug( - 'Tunnel to {0} is DOWN'.format(_srv.remote_address) - ) + self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get(timeout=TUNNEL_TIMEOUT * 1.1) + self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) except socket.error: - self.logger.debug( - 'Tunnel to {0} is DOWN'.format(_srv.remote_address) - ) + self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) self.tunnel_is_up[_srv.local_address] = False except queue.Empty: - self.logger.debug( - 'Tunnel to {0} is UP'.format(_srv.remote_address) - ) + self.logger.debug("Tunnel to {0} is UP".format(_srv.remote_address)) self.tunnel_is_up[_srv.local_address] = True finally: s.close() @@ -828,18 +785,19 @@ def _make_ssh_forward_handler_class(self, remote_address_): """ Make SSH Handler class """ + class Handler(_ForwardHandler): remote_address = remote_address_ ssh_transport = self._transport logger = self.logger + return Handler def _make_ssh_forward_server_class(self, remote_address_): return _ThreadingForwardServer if self._threaded else _ForwardServer def _make_stream_ssh_forward_server_class(self, remote_address_): - return _ThreadingStreamForwardServer if self._threaded \ - else _StreamForwardServer + return _ThreadingStreamForwardServer if self._threaded else _StreamForwardServer def _make_ssh_forward_server(self, remote_address, local_bind_address): """ @@ -847,9 +805,11 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): """ _Handler = self._make_ssh_forward_handler_class(remote_address) try: - forward_maker_class = self._make_stream_ssh_forward_server_class \ - if isinstance(local_bind_address, string_types) \ + forward_maker_class = ( + self._make_stream_ssh_forward_server_class + if isinstance(local_bind_address, string_types) else self._make_ssh_forward_server_class + ) _Server = forward_maker_class(remote_address) ssh_forward_server = _Server( local_bind_address, @@ -864,45 +824,42 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): else: self._raise( BaseSSHTunnelForwarderError, - 'Problem setting up ssh {0} <> {1} forwarder. You can ' - 'suppress this exception by using the `mute_exceptions`' - 'argument'.format(address_to_str(local_bind_address), - address_to_str(remote_address)) + "Problem setting up ssh {0} <> {1} forwarder. You can " + "suppress this exception by using the `mute_exceptions`" + "argument".format(address_to_str(local_bind_address), address_to_str(remote_address)), ) except IOError: self._raise( BaseSSHTunnelForwarderError, - "Couldn't open tunnel {0} <> {1} might be in use or " - "destination not reachable".format( - address_to_str(local_bind_address), - address_to_str(remote_address) - ) + "Couldn't open tunnel {0} <> {1} might be in use or destination not reachable".format( + address_to_str(local_bind_address), address_to_str(remote_address) + ), ) def __init__( - self, - ssh_address_or_host=None, - ssh_config_file=SSH_CONFIG_FILE, - ssh_host_key=None, - ssh_password=None, - ssh_pkey=None, - ssh_private_key_password=None, - ssh_proxy=None, - ssh_proxy_enabled=True, - ssh_username=None, - local_bind_address=None, - local_bind_addresses=None, - logger=None, - mute_exceptions=False, - remote_bind_address=None, - remote_bind_addresses=None, - set_keepalive=5.0, - threaded=True, # old version False - compression=None, - allow_agent=True, # look for keys from an SSH agent - host_pkey_directories=None, # look for keys in ~/.ssh - *args, - **kwargs # for backwards compatibility + self, + ssh_address_or_host=None, + ssh_config_file=SSH_CONFIG_FILE, + ssh_host_key=None, + ssh_password=None, + ssh_pkey=None, + ssh_private_key_password=None, + ssh_proxy=None, + ssh_proxy_enabled=True, + ssh_username=None, + local_bind_address=None, + local_bind_addresses=None, + logger=None, + mute_exceptions=False, + remote_bind_address=None, + remote_bind_addresses=None, + set_keepalive=5.0, + threaded=True, # old version False + compression=None, + allow_agent=True, # look for keys from an SSH agent + host_pkey_directories=None, # look for keys in ~/.ssh + *args, + **kwargs, # for backwards compatibility ): self.logger = logger or create_logger() @@ -913,54 +870,48 @@ def __init__( self._threaded = threaded self.is_alive = False # Check if deprecated arguments ssh_address or ssh_host were used - for deprecated_argument in ['ssh_address', 'ssh_host']: - ssh_address_or_host = self._process_deprecated(ssh_address_or_host, - deprecated_argument, - kwargs) + for deprecated_argument in ["ssh_address", "ssh_host"]: + ssh_address_or_host = self._process_deprecated(ssh_address_or_host, deprecated_argument, kwargs) # other deprecated arguments - ssh_pkey = self._process_deprecated(ssh_pkey, - 'ssh_private_key', - kwargs) + ssh_pkey = self._process_deprecated(ssh_pkey, "ssh_private_key", kwargs) - self._raise_fwd_exc = self._process_deprecated( - None, - 'raise_exception_if_any_forwarder_have_a_problem', - kwargs) or not mute_exceptions + self._raise_fwd_exc = ( + self._process_deprecated(None, "raise_exception_if_any_forwarder_have_a_problem", kwargs) + or not mute_exceptions + ) if isinstance(ssh_address_or_host, tuple): check_address(ssh_address_or_host) (ssh_host, ssh_port) = ssh_address_or_host else: ssh_host = ssh_address_or_host - ssh_port = kwargs.pop('ssh_port', None) + ssh_port = kwargs.pop("ssh_port", None) if kwargs: - raise ValueError('Unknown arguments: {0}'.format(kwargs)) + raise ValueError("Unknown arguments: {0}".format(kwargs)) # remote binds - self._remote_binds = self._get_binds(remote_bind_address, - remote_bind_addresses, - is_remote=True) + self._remote_binds = self._get_binds(remote_bind_address, remote_bind_addresses, is_remote=True) # local binds - self._local_binds = self._get_binds(local_bind_address, - local_bind_addresses) - self._local_binds = self._consolidate_binds(self._local_binds, - self._remote_binds) - - (self.ssh_host, - self.ssh_username, - ssh_pkey, # still needs to go through _consolidate_auth - self.ssh_port, - self.ssh_proxy, - self.compression) = self._read_ssh_config( - ssh_host, - ssh_config_file, - ssh_username, - ssh_pkey, - ssh_port, - ssh_proxy if ssh_proxy_enabled else None, - compression, - self.logger + self._local_binds = self._get_binds(local_bind_address, local_bind_addresses) + self._local_binds = self._consolidate_binds(self._local_binds, self._remote_binds) + + ( + self.ssh_host, + self.ssh_username, + ssh_pkey, # still needs to go through _consolidate_auth + self.ssh_port, + self.ssh_proxy, + self.compression, + ) = self._read_ssh_config( + ssh_host, + ssh_config_file, + ssh_username, + ssh_pkey, + ssh_port, + ssh_proxy if ssh_proxy_enabled else None, + compression, + self.logger, ) (self.ssh_password, self.ssh_pkeys) = self._consolidate_auth( @@ -969,29 +920,29 @@ def __init__( ssh_pkey_password=ssh_private_key_password, allow_agent=allow_agent, host_pkey_directories=host_pkey_directories, - logger=self.logger + logger=self.logger, ) check_host(self.ssh_host) check_port(self.ssh_port) - self.logger.info("Connecting to gateway: {0}:{1} as user '{2}'" - .format(self.ssh_host, - self.ssh_port, - self.ssh_username)) + self.logger.info( + "Connecting to gateway: {0}:{1} as user '{2}'".format(self.ssh_host, self.ssh_port, self.ssh_username) + ) - self.logger.debug('Concurrent connections allowed: {0}' - .format(self._threaded)) + self.logger.debug("Concurrent connections allowed: {0}".format(self._threaded)) @staticmethod - def _read_ssh_config(ssh_host, - ssh_config_file, - ssh_username=None, - ssh_pkey=None, - ssh_port=None, - ssh_proxy=None, - compression=None, - logger=None): + def _read_ssh_config( + ssh_host, + ssh_config_file, + ssh_username=None, + ssh_pkey=None, + ssh_port=None, + ssh_proxy=None, + compression=None, + logger=None, + ): """ Read ssh_config_file and tries to look for user (ssh_username), identityfile (ssh_pkey), port (ssh_port) and proxycommand @@ -1004,49 +955,41 @@ def _read_ssh_config(ssh_host, # Try to read SSH_CONFIG_FILE try: # open the ssh config file - with open(os.path.expanduser(ssh_config_file), 'r') as f: + with open(os.path.expanduser(ssh_config_file), "r") as f: ssh_config.parse(f) # looks for information for the destination system hostname_info = ssh_config.lookup(ssh_host) # gather settings for user, port and identity file # last resort: use the 'login name' of the user - ssh_username = ( - ssh_username or - hostname_info.get('user') - ) - ssh_pkey = ( - ssh_pkey or - hostname_info.get('identityfile', [None])[0] - ) - ssh_host = hostname_info.get('hostname') - ssh_port = ssh_port or hostname_info.get('port') + ssh_username = ssh_username or hostname_info.get("user") + ssh_pkey = ssh_pkey or hostname_info.get("identityfile", [None])[0] + ssh_host = hostname_info.get("hostname") + ssh_port = ssh_port or hostname_info.get("port") - proxycommand = hostname_info.get('proxycommand') - ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if - proxycommand else None) + proxycommand = hostname_info.get("proxycommand") + ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if proxycommand else None) if compression is None: - compression = hostname_info.get('compression', '') - compression = True if compression.upper() == 'YES' else False + compression = hostname_info.get("compression", "") + compression = True if compression.upper() == "YES" else False except IOError: if logger: - logger.warning( - 'Could not read SSH configuration file: {0}' - .format(ssh_config_file) - ) + logger.warning("Could not read SSH configuration file: {0}".format(ssh_config_file)) except (AttributeError, TypeError): # ssh_config_file is None if logger: - logger.info('Skipping loading of ssh configuration file') + logger.info("Skipping loading of ssh configuration file") finally: - return (ssh_host, - ssh_username or getpass.getuser(), - ssh_pkey, - int(ssh_port) if ssh_port else 22, # fallback value - ssh_proxy, - compression) + return ( + ssh_host, + ssh_username or getpass.getuser(), + ssh_pkey, + int(ssh_port) if ssh_port else 22, # fallback value + ssh_proxy, + compression, + ) @staticmethod def get_agent_keys(logger=None): - """ Load public keys from any available SSH agent + """Load public keys from any available SSH agent Arguments: logger (Optional[logging.Logger]) @@ -1057,7 +1000,7 @@ def get_agent_keys(logger=None): paramiko_agent = paramiko.Agent() agent_keys = paramiko_agent.get_keys() if logger: - logger.info('{0} keys loaded from agent'.format(len(agent_keys))) + logger.info("{0} keys loaded from agent".format(len(agent_keys))) return list(agent_keys) @staticmethod @@ -1083,38 +1026,31 @@ def get_keys(logger=None, host_pkey_directories=None, allow_agent=False): Return: list """ - keys = SSHTunnelForwarder.get_agent_keys(logger=logger) \ - if allow_agent else [] + keys = SSHTunnelForwarder.get_agent_keys(logger=logger) if allow_agent else [] if host_pkey_directories is None: host_pkey_directories = [DEFAULT_SSH_DIRECTORY] - paramiko_key_types = {'rsa': paramiko.RSAKey, - 'dsa': paramiko.DSSKey, - 'ecdsa': paramiko.ECDSAKey} - if hasattr(paramiko, 'Ed25519Key'): - # NOQA: new in paramiko>=2.2: http://docs.paramiko.org/en/stable/api/keys.html#module-paramiko.ed25519key - paramiko_key_types['ed25519'] = paramiko.Ed25519Key + paramiko_key_types = { + "rsa": paramiko.RSAKey, + "ecdsa": paramiko.ECDSAKey, + "ed25519": paramiko.Ed25519Key, + } for directory in host_pkey_directories: for keytype in paramiko_key_types.keys(): - ssh_pkey_expanded = os.path.expanduser( - os.path.join(directory, 'id_{}'.format(keytype)) - ) + ssh_pkey_expanded = os.path.expanduser(os.path.join(directory, "id_{}".format(keytype))) try: if os.path.isfile(ssh_pkey_expanded): ssh_pkey = SSHTunnelForwarder.read_private_key_file( - pkey_file=ssh_pkey_expanded, - logger=logger, - key_type=paramiko_key_types[keytype] + pkey_file=ssh_pkey_expanded, logger=logger, key_type=paramiko_key_types[keytype] ) if ssh_pkey: keys.append(ssh_pkey) except OSError as exc: if logger: - logger.warning('Private key file {0} check error: {1}' - .format(ssh_pkey_expanded, exc)) + logger.warning("Private key file {0} check error: {1}".format(ssh_pkey_expanded, exc)) if logger: - logger.info('{0} key(s) loaded'.format(len(keys))) + logger.info("{0} key(s) loaded".format(len(keys))) return keys @staticmethod @@ -1125,18 +1061,19 @@ def _consolidate_binds(local_binds, remote_binds): """ count = len(remote_binds) - len(local_binds) if count < 0: - raise ValueError('Too many local bind addresses ' - '(local_bind_addresses > remote_bind_addresses)') - local_binds.extend([('0.0.0.0', 0) for x in range(count)]) + raise ValueError("Too many local bind addresses (local_bind_addresses > remote_bind_addresses)") + local_binds.extend([("0.0.0.0", 0) for x in range(count)]) return local_binds @staticmethod - def _consolidate_auth(ssh_password=None, - ssh_pkey=None, - ssh_pkey_password=None, - allow_agent=True, - host_pkey_directories=None, - logger=None): + def _consolidate_auth( + ssh_password=None, + ssh_pkey=None, + ssh_pkey_password=None, + allow_agent=True, + host_pkey_directories=None, + logger=None, + ): """ Get sure authentication information is in place. ``ssh_pkey`` may be of classes: @@ -1146,27 +1083,22 @@ def _consolidate_auth(ssh_password=None, """ ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( - logger=logger, - host_pkey_directories=host_pkey_directories, - allow_agent=allow_agent + logger=logger, host_pkey_directories=host_pkey_directories, allow_agent=allow_agent ) if isinstance(ssh_pkey, string_types): ssh_pkey_expanded = os.path.expanduser(ssh_pkey) if os.path.exists(ssh_pkey_expanded): ssh_pkey = SSHTunnelForwarder.read_private_key_file( - pkey_file=ssh_pkey_expanded, - pkey_password=ssh_pkey_password or ssh_password, - logger=logger + pkey_file=ssh_pkey_expanded, pkey_password=ssh_pkey_password or ssh_password, logger=logger ) elif logger: - logger.warning('Private key file not found: {0}' - .format(ssh_pkey)) + logger.warning("Private key file not found: {0}".format(ssh_pkey)) if isinstance(ssh_pkey, paramiko.pkey.PKey): ssh_loaded_pkeys.insert(0, ssh_pkey) if not ssh_password and not ssh_loaded_pkeys: - raise ValueError('No password or public key available!') + raise ValueError("No password or public key available!") return (ssh_password, ssh_loaded_pkeys) def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): @@ -1176,13 +1108,13 @@ def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): self.logger.error(repr(exception(reason))) def _get_transport(self): - """ Return the SSH transport to the remote gateway """ + """Return the SSH transport to the remote gateway""" if self.ssh_proxy: if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand): proxy_repr = repr(self.ssh_proxy.cmd[1]) else: proxy_repr = repr(self.ssh_proxy) - self.logger.debug('Connecting via proxy: {0}'.format(proxy_repr)) + self.logger.debug("Connecting via proxy: {0}".format(proxy_repr)) _socket = self.ssh_proxy else: _socket = (self.ssh_host, self.ssh_port) @@ -1201,8 +1133,7 @@ def _get_transport(self): if isinstance(sock, socket.socket): sock_timeout = sock.gettimeout() sock_info = repr((sock.family, sock.type, sock.proto)) - self.logger.debug('Transport socket info: {0}, timeout={1}' - .format(sock_info, sock_timeout)) + self.logger.debug("Transport socket info: {0}, timeout={1}".format(sock_info, sock_timeout)) return transport def _create_tunnels(self): @@ -1213,42 +1144,45 @@ def _create_tunnels(self): try: self._connect_to_gateway() except socket.gaierror: # raised by paramiko.Transport - msg = 'Could not resolve IP address for {0}, aborting!' \ - .format(self.ssh_host) + msg = "Could not resolve IP address for {0}, aborting!".format(self.ssh_host) self.logger.error(msg) return except (paramiko.SSHException, socket.error) as e: - template = 'Could not connect to gateway {0}:{1} : {2}' + template = "Could not connect to gateway {0}:{1} : {2}" msg = template.format(self.ssh_host, self.ssh_port, e.args[0]) self.logger.error(msg) return - for (rem, loc) in zip(self._remote_binds, self._local_binds): + for rem, loc in zip(self._remote_binds, self._local_binds): try: self._make_ssh_forward_server(rem, loc) except BaseSSHTunnelForwarderError as e: - msg = 'Problem setting SSH Forwarder up: {0}'.format(e.value) + msg = "Problem setting SSH Forwarder up: {0}".format(e.value) self.logger.error(msg) @staticmethod def _get_binds(bind_address, bind_addresses, is_remote=False): - addr_kind = 'remote' if is_remote else 'local' + addr_kind = "remote" if is_remote else "local" if not bind_address and not bind_addresses: if is_remote: - raise ValueError("No {0} bind addresses specified. Use " - "'{0}_bind_address' or '{0}_bind_addresses'" - " argument".format(addr_kind)) + raise ValueError( + "No {0} bind addresses specified. Use '{0}_bind_address' or '{0}_bind_addresses' argument".format( + addr_kind + ) + ) else: return [] elif bind_address and bind_addresses: - raise ValueError("You can't use both '{0}_bind_address' and " - "'{0}_bind_addresses' arguments. Use one of " - "them.".format(addr_kind)) + raise ValueError( + "You can't use both '{0}_bind_address' and '{0}_bind_addresses' arguments. Use one of them.".format( + addr_kind + ) + ) if bind_address: bind_addresses = [bind_address] if not is_remote: # Add random port if missing in local bind - for (i, local_bind) in enumerate(bind_addresses): + for i, local_bind in enumerate(bind_addresses): if isinstance(local_bind, tuple) and len(local_bind) == 1: bind_addresses[i] = (local_bind[0], 0) check_addresses(bind_addresses, is_remote) @@ -1260,33 +1194,30 @@ def _process_deprecated(attrib, deprecated_attrib, kwargs): Processes optional deprecate arguments """ if deprecated_attrib not in _DEPRECATIONS: - raise ValueError('{0} not included in deprecations list' - .format(deprecated_attrib)) + raise ValueError("{0} not included in deprecations list".format(deprecated_attrib)) if deprecated_attrib in kwargs: - warnings.warn("'{0}' is DEPRECATED use '{1}' instead" - .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib]), - DeprecationWarning) + warnings.warn( + "'{0}' is DEPRECATED use '{1}' instead".format(deprecated_attrib, _DEPRECATIONS[deprecated_attrib]), + DeprecationWarning, + ) if attrib: - raise ValueError("You can't use both '{0}' and '{1}'. " - "Please only use one of them" - .format(deprecated_attrib, - _DEPRECATIONS[deprecated_attrib])) + raise ValueError( + "You can't use both '{0}' and '{1}'. Please only use one of them".format( + deprecated_attrib, _DEPRECATIONS[deprecated_attrib] + ) + ) else: return kwargs.pop(deprecated_attrib) return attrib @staticmethod - def read_private_key_file(pkey_file, - pkey_password=None, - key_type=None, - logger=None): + def read_private_key_file(pkey_file, pkey_password=None, key_type=None, logger=None): """ Get SSH Public key from a private key file, given an optional password Arguments: pkey_file (str): - File containing a private key (RSA, DSS or ECDSA) + File containing a private key (RSA, ECDSA or Ed25519) Keyword Arguments: pkey_password (Optional[str]): Password to decrypt the private key @@ -1295,54 +1226,44 @@ def read_private_key_file(pkey_file, paramiko.Pkey """ ssh_pkey = None - key_types = (paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey) - if hasattr(paramiko, 'Ed25519Key'): - # NOQA: new in paramiko>=2.2: http://docs.paramiko.org/en/stable/api/keys.html#module-paramiko.ed25519key - key_types += (paramiko.Ed25519Key, ) + key_types = (paramiko.RSAKey, paramiko.ECDSAKey, paramiko.Ed25519Key) for pkey_class in (key_type,) if key_type else key_types: try: - ssh_pkey = pkey_class.from_private_key_file( - pkey_file, - password=pkey_password - ) + ssh_pkey = pkey_class.from_private_key_file(pkey_file, password=pkey_password) if logger: - logger.debug('Private key file ({0}, {1}) successfully ' - 'loaded'.format(pkey_file, pkey_class)) + logger.debug("Private key file ({0}, {1}) successfully loaded".format(pkey_file, pkey_class)) break except paramiko.PasswordRequiredException: if logger: - logger.error('Password is required for key {0}' - .format(pkey_file)) + logger.error("Password is required for key {0}".format(pkey_file)) break except paramiko.SSHException: if logger: - logger.debug('Private key file ({0}) could not be loaded ' - 'as type {1} or bad password' - .format(pkey_file, pkey_class)) + logger.debug( + "Private key file ({0}) could not be loaded as type {1} or bad password".format( + pkey_file, pkey_class + ) + ) return ssh_pkey def start(self): - """ Start the SSH tunnels """ + """Start the SSH tunnels""" if self.is_alive: - self.logger.warning('Already started!') + self.logger.warning("Already started!") return self._create_tunnels() if not self.is_active: - self._raise(BaseSSHTunnelForwarderError, - reason='Could not establish session to SSH gateway') + self._raise(BaseSSHTunnelForwarderError, reason="Could not establish session to SSH gateway") for _srv in self._server_list: thread = threading.Thread( - target=self._serve_forever_wrapper, - args=(_srv, ), - name='Srv-{0}'.format(address_to_str(_srv.local_port)) + target=self._serve_forever_wrapper, args=(_srv,), name="Srv-{0}".format(address_to_str(_srv.local_port)) ) thread.daemon = self.daemon_forward_servers thread.start() self._check_tunnel(_srv) self.is_alive = any(self.tunnel_is_up.values()) if not self.is_alive: - self._raise(HandlerSSHTunnelForwarderError, - 'An error occurred while opening tunnels.') + self._raise(HandlerSSHTunnelForwarderError, "An error occurred while opening tunnels.") def stop(self, force=False): """ @@ -1371,21 +1292,19 @@ def stop(self, force=False): Handle these scenarios with :attr:`.tunnel_is_up`: if False, server ``shutdown()`` will be skipped on that tunnel """ - self.logger.info('Closing all open connections...') - opened_address_text = ', '.join( - (address_to_str(k.local_address) for k in self._server_list) - ) or 'None' - self.logger.debug('Listening tunnels: ' + opened_address_text) + self.logger.info("Closing all open connections...") + opened_address_text = ", ".join((address_to_str(k.local_address) for k in self._server_list)) or "None" + self.logger.debug("Listening tunnels: " + opened_address_text) self._stop_transport(force=force) self._server_list = [] # reset server list self.tunnel_is_up = {} # reset tunnel status def close(self): - """ Stop the an active tunnel, alias to :meth:`.stop` """ + """Stop the an active tunnel, alias to :meth:`.stop`""" self.stop() def restart(self): - """ Restart connection to the gateway and tunnels """ + """Restart connection to the gateway and tunnels""" self.stop() self.start() @@ -1397,69 +1316,64 @@ def _connect_to_gateway(self): - As last resort, try with a provided password """ for key in self.ssh_pkeys: - self.logger.debug('Trying to log in with key: {0}' - .format(hexlify(key.get_fingerprint()))) + self.logger.debug("Trying to log in with key: {0}".format(hexlify(key.get_fingerprint()))) try: self._transport = self._get_transport() - self._transport.connect(hostkey=self.ssh_host_key, - username=self.ssh_username, - pkey=key) + self._transport.connect(hostkey=self.ssh_host_key, username=self.ssh_username, pkey=key) if self._transport.is_alive: return except paramiko.AuthenticationException: - self.logger.debug('Authentication error') + self.logger.debug("Authentication error") self._stop_transport() if self.ssh_password: # avoid conflict using both pass and pkey - self.logger.debug('Trying to log in with password: {0}' - .format('*' * len(self.ssh_password))) + self.logger.debug("Trying to log in with password: {0}".format("*" * len(self.ssh_password))) try: self._transport = self._get_transport() - self._transport.connect(hostkey=self.ssh_host_key, - username=self.ssh_username, - password=self.ssh_password) + self._transport.connect( + hostkey=self.ssh_host_key, username=self.ssh_username, password=self.ssh_password + ) if self._transport.is_alive: return except paramiko.AuthenticationException: - self.logger.debug('Authentication error') + self.logger.debug("Authentication error") self._stop_transport() - self.logger.error('Could not open connection to gateway') + self.logger.error("Could not open connection to gateway") def _serve_forever_wrapper(self, _srv, poll_interval=0.1): """ Wrapper for the server created for a SSH forward """ - self.logger.info('Opening tunnel: {0} <> {1}'.format( - address_to_str(_srv.local_address), - address_to_str(_srv.remote_address)) + self.logger.info( + "Opening tunnel: {0} <> {1}".format(address_to_str(_srv.local_address), address_to_str(_srv.remote_address)) ) _srv.serve_forever(poll_interval) # blocks until finished - self.logger.info('Tunnel: {0} <> {1} released'.format( - address_to_str(_srv.local_address), - address_to_str(_srv.remote_address)) + self.logger.info( + "Tunnel: {0} <> {1} released".format( + address_to_str(_srv.local_address), address_to_str(_srv.remote_address) + ) ) def _stop_transport(self, force=False): - """ Close the underlying transport when nothing more is needed """ + """Close the underlying transport when nothing more is needed""" try: self._check_is_started() - except (BaseSSHTunnelForwarderError, - HandlerSSHTunnelForwarderError) as e: + except (BaseSSHTunnelForwarderError, HandlerSSHTunnelForwarderError) as e: self.logger.warning(e) if force and self.is_active: # don't wait connections - self.logger.info('Closing ssh transport') + self.logger.info("Closing ssh transport") self._transport.close() self._transport.stop_thread() for _srv in self._server_list: - status = 'up' if self.tunnel_is_up[_srv.local_address] else 'down' - self.logger.info('Shutting down tunnel: {0} <> {1} ({2})'.format( - address_to_str(_srv.local_address), - address_to_str(_srv.remote_address), - status - )) + status = "up" if self.tunnel_is_up[_srv.local_address] else "down" + self.logger.info( + "Shutting down tunnel: {0} <> {1} ({2})".format( + address_to_str(_srv.local_address), address_to_str(_srv.remote_address), status + ) + ) _srv.shutdown() _srv.server_close() # clean up the UNIX domain socket if we're using one @@ -1467,23 +1381,20 @@ def _stop_transport(self, force=False): try: os.unlink(_srv.local_address) except Exception as e: - self.logger.error('Unable to unlink socket {0}: {1}' - .format(_srv.local_address, repr(e))) + self.logger.error("Unable to unlink socket {0}: {1}".format(_srv.local_address, repr(e))) self.is_alive = False if self.is_active: - self.logger.info('Closing ssh transport') + self.logger.info("Closing ssh transport") self._transport.close() self._transport.stop_thread() - self.logger.debug('Transport is closed') + self.logger.debug("Transport is closed") @property def local_bind_port(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: - raise BaseSSHTunnelForwarderError( - 'Use .local_bind_ports property for more than one tunnel' - ) + raise BaseSSHTunnelForwarderError("Use .local_bind_ports property for more than one tunnel") return self.local_bind_ports[0] @property @@ -1491,9 +1402,7 @@ def local_bind_host(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: - raise BaseSSHTunnelForwarderError( - 'Use .local_bind_hosts property for more than one tunnel' - ) + raise BaseSSHTunnelForwarderError("Use .local_bind_hosts property for more than one tunnel") return self.local_bind_hosts[0] @property @@ -1501,9 +1410,7 @@ def local_bind_address(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: - raise BaseSSHTunnelForwarderError( - 'Use .local_bind_addresses property for more than one tunnel' - ) + raise BaseSSHTunnelForwarderError("Use .local_bind_addresses property for more than one tunnel") return self.local_bind_addresses[0] @property @@ -1512,8 +1419,7 @@ def local_bind_ports(self): Return a list containing the ports of local side of the TCP tunnels """ self._check_is_started() - return [_server.local_port for _server in self._server_list if - _server.local_port is not None] + return [_server.local_port for _server in self._server_list if _server.local_port is not None] @property def local_bind_hosts(self): @@ -1521,8 +1427,7 @@ def local_bind_hosts(self): Return a list containing the IP addresses listening for the tunnels """ self._check_is_started() - return [_server.local_host for _server in self._server_list if - _server.local_host is not None] + return [_server.local_host for _server in self._server_list if _server.local_host is not None] @property def local_bind_addresses(self): @@ -1537,68 +1442,70 @@ def tunnel_bindings(self): """ Return a dictionary containing the active local<>remote tunnel_bindings """ - return dict((_server.remote_address, _server.local_address) for - _server in self._server_list if - self.tunnel_is_up[_server.local_address]) + return dict( + (_server.remote_address, _server.local_address) + for _server in self._server_list + if self.tunnel_is_up[_server.local_address] + ) @property def is_active(self): - """ Return True if the underlying SSH transport is up """ - if ( - '_transport' in self.__dict__ and - self._transport.is_active() - ): + """Return True if the underlying SSH transport is up""" + if "_transport" in self.__dict__ and self._transport.is_active(): return True return False def _check_is_started(self): if not self.is_active: # underlying transport not alive - msg = 'Server is not started. Please .start() first!' + msg = "Server is not started. Please .start() first!" raise BaseSSHTunnelForwarderError(msg) if not self.is_alive: - msg = 'Tunnels are not started. Please .start() first!' + msg = "Tunnels are not started. Please .start() first!" raise HandlerSSHTunnelForwarderError(msg) def __str__(self): credentials = { - 'password': self.ssh_password, - 'pkeys': [(key.get_name(), hexlify(key.get_fingerprint())) - for key in self.ssh_pkeys] - if any(self.ssh_pkeys) else None + "password": self.ssh_password, + "pkeys": [(key.get_name(), hexlify(key.get_fingerprint())) for key in self.ssh_pkeys] + if any(self.ssh_pkeys) + else None, } _remove_none_values(credentials) - template = os.linesep.join(['{0} object', - 'ssh gateway: {1}:{2}', - 'proxy: {3}', - 'username: {4}', - 'authentication: {5}', - 'hostkey: {6}', - 'status: {7}started', - 'keepalive messages: {8}', - 'tunnel connection check: {9}', - 'concurrent connections: {10}allowed', - 'compression: {11}requested', - 'logging level: {12}', - 'local binds: {13}', - 'remote binds: {14}']) - return (template.format( + template = os.linesep.join( + [ + "{0} object", + "ssh gateway: {1}:{2}", + "proxy: {3}", + "username: {4}", + "authentication: {5}", + "hostkey: {6}", + "status: {7}started", + "keepalive messages: {8}", + "tunnel connection check: {9}", + "concurrent connections: {10}allowed", + "compression: {11}requested", + "logging level: {12}", + "local binds: {13}", + "remote binds: {14}", + ] + ) + return template.format( self.__class__, self.ssh_host, self.ssh_port, - self.ssh_proxy.cmd[1] if self.ssh_proxy else 'no', + self.ssh_proxy.cmd[1] if self.ssh_proxy else "no", self.ssh_username, credentials, - self.ssh_host_key if self.ssh_host_key else 'not checked', - '' if self.is_alive else 'not ', - 'disabled' if not self.set_keepalive else - 'every {0} sec'.format(self.set_keepalive), - 'disabled' if self.skip_tunnel_checkup else 'enabled', - '' if self._threaded else 'not ', - '' if self.compression else 'not ', + self.ssh_host_key if self.ssh_host_key else "not checked", + "" if self.is_alive else "not ", + "disabled" if not self.set_keepalive else "every {0} sec".format(self.set_keepalive), + "disabled" if self.skip_tunnel_checkup else "enabled", + "" if self._threaded else "not ", + "" if self.compression else "not ", logging.getLevelName(self.logger.level), self._local_binds, self._remote_binds, - )) + ) def __repr__(self): return self.__str__() @@ -1619,7 +1526,8 @@ def __del__(self): self.logger.warning( "It looks like you didn't call the .stop() before " "the SSHTunnelForwarder obj was collected by " - "the garbage collector! Running .stop(force=True)") + "the garbage collector! Running .stop(force=True)" + ) self.stop(force=True) @@ -1667,44 +1575,42 @@ def do_something(port): do_something(server.local_bind_port) """ # Attach a console handler to the logger or create one if not passed - loglevel = kwargs.pop('debug_level', None) - logger = kwargs.get('logger', None) or create_logger(loglevel=loglevel) - kwargs['logger'] = logger + loglevel = kwargs.pop("debug_level", None) + logger = kwargs.get("logger", None) or create_logger(loglevel=loglevel) + kwargs["logger"] = logger - ssh_address_or_host = kwargs.pop('ssh_address_or_host', None) + ssh_address_or_host = kwargs.pop("ssh_address_or_host", None) # Check if deprecated arguments ssh_address or ssh_host were used - for deprecated_argument in ['ssh_address', 'ssh_host']: - ssh_address_or_host = SSHTunnelForwarder._process_deprecated( - ssh_address_or_host, - deprecated_argument, - kwargs - ) + for deprecated_argument in ["ssh_address", "ssh_host"]: + ssh_address_or_host = SSHTunnelForwarder._process_deprecated(ssh_address_or_host, deprecated_argument, kwargs) - ssh_port = kwargs.pop('ssh_port', 22) - skip_tunnel_checkup = kwargs.pop('skip_tunnel_checkup', True) - block_on_close = kwargs.pop('block_on_close', None) + ssh_port = kwargs.pop("ssh_port", 22) + skip_tunnel_checkup = kwargs.pop("skip_tunnel_checkup", True) + block_on_close = kwargs.pop("block_on_close", None) if block_on_close: - warnings.warn("'block_on_close' is DEPRECATED. You should use either" - " .stop() or .stop(force=True), depends on what you do" - " with the active connections. This option has no" - " affect since 0.3.0", - DeprecationWarning) + warnings.warn( + "'block_on_close' is DEPRECATED. You should use either" + " .stop() or .stop(force=True), depends on what you do" + " with the active connections. This option has no" + " affect since 0.3.0", + DeprecationWarning, + ) if not args: if isinstance(ssh_address_or_host, tuple): - args = (ssh_address_or_host, ) + args = (ssh_address_or_host,) else: - args = ((ssh_address_or_host, ssh_port), ) + args = ((ssh_address_or_host, ssh_port),) forwarder = SSHTunnelForwarder(*args, **kwargs) forwarder.skip_tunnel_checkup = skip_tunnel_checkup return forwarder def _bindlist(input_str): - """ Define type of data expected for remote and local bind address lists - Returns a tuple (ip_address, port) whose elements are (str, int) + """Define type of data expected for remote and local bind address lists + Returns a tuple (ip_address, port) whose elements are (str, int) """ try: - ip_port = input_str.split(':') + ip_port = input_str.split(":") if len(ip_port) == 1: _ip = ip_port[0] _port = None @@ -1713,12 +1619,10 @@ def _bindlist(input_str): if not _ip and not _port: raise AssertionError elif not _port: - _port = '22' # default port if not given + _port = "22" # default port if not given return _ip, int(_port) except ValueError: - raise argparse.ArgumentTypeError( - 'Address tuple must be of type IP_ADDRESS:PORT' - ) + raise argparse.ArgumentTypeError("Address tuple must be of type IP_ADDRESS:PORT") except AssertionError: raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!") @@ -1728,201 +1632,175 @@ def _parse_arguments(args=None): Parse arguments directly passed from CLI """ parser = argparse.ArgumentParser( - description='Pure python ssh tunnel utils\n' - 'Version {0}'.format(__version__), - formatter_class=argparse.RawTextHelpFormatter + description="Pure python ssh tunnel utils\nVersion {0}".format(__version__), + formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument( - 'ssh_address', + "ssh_address", type=str, - help='SSH server IP address (GW for SSH tunnels)\n' - 'set with "-- ssh_address" if immediately after ' - '-R or -L' + help='SSH server IP address (GW for SSH tunnels)\nset with "-- ssh_address" if immediately after -R or -L', ) - parser.add_argument( - '-U', '--username', - type=str, - dest='ssh_username', - help='SSH server account username' - ) + parser.add_argument("-U", "--username", type=str, dest="ssh_username", help="SSH server account username") parser.add_argument( - '-p', '--server_port', - type=int, - dest='ssh_port', - default=22, - help='SSH server TCP port (default: 22)' + "-p", "--server_port", type=int, dest="ssh_port", default=22, help="SSH server TCP port (default: 22)" ) - parser.add_argument( - '-P', '--password', - type=str, - dest='ssh_password', - help='SSH server account password' - ) + parser.add_argument("-P", "--password", type=str, dest="ssh_password", help="SSH server account password") parser.add_argument( - '-R', '--remote_bind_address', + "-R", + "--remote_bind_address", type=_bindlist, - nargs='+', + nargs="+", default=[], - metavar='IP:PORT', + metavar="IP:PORT", required=True, - dest='remote_bind_addresses', - help='Remote bind address sequence: ' - 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' - 'Equivalent to ssh -Lxxxx:IP_ADDRESS:PORT\n' - 'If port is omitted, defaults to 22.\n' - 'Example: -R 10.10.10.10: 10.10.10.10:5900' + dest="remote_bind_addresses", + help="Remote bind address sequence: " + "ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n" + "Equivalent to ssh -Lxxxx:IP_ADDRESS:PORT\n" + "If port is omitted, defaults to 22.\n" + "Example: -R 10.10.10.10: 10.10.10.10:5900", ) parser.add_argument( - '-L', '--local_bind_address', + "-L", + "--local_bind_address", type=_bindlist, - nargs='*', - dest='local_bind_addresses', - metavar='IP:PORT', - help='Local bind address sequence: ' - 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' - 'Elements may also be valid UNIX socket domains: \n' - '/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n' - 'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, ' - 'being the local IP address optional.\n' - 'By default it will listen in all interfaces ' - '(0.0.0.0) and choose a random port.\n' - 'Example: -L :40000' + nargs="*", + dest="local_bind_addresses", + metavar="IP:PORT", + help="Local bind address sequence: " + "ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n" + "Elements may also be valid UNIX socket domains: \n" + "/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n" + "Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, " + "being the local IP address optional.\n" + "By default it will listen in all interfaces " + "(0.0.0.0) and choose a random port.\n" + "Example: -L :40000", ) - parser.add_argument( - '-k', '--ssh_host_key', - type=str, - help="Gateway's host key" - ) + parser.add_argument("-k", "--ssh_host_key", type=str, help="Gateway's host key") parser.add_argument( - '-K', '--private_key_file', - dest='ssh_private_key', - metavar='KEY_FILE', + "-K", + "--private_key_file", + dest="ssh_private_key", + metavar="KEY_FILE", type=str, - help='RSA/DSS/ECDSA private key file' + help="RSA/ECDSA/Ed25519 private key file", ) parser.add_argument( - '-S', '--private_key_password', - dest='ssh_private_key_password', - metavar='KEY_PASSWORD', + "-S", + "--private_key_password", + dest="ssh_private_key_password", + metavar="KEY_PASSWORD", type=str, - help='RSA/DSS/ECDSA private key password' + help="RSA/ECDSA/Ed25519 private key password", ) - parser.add_argument( - '-t', '--threaded', - action='store_true', - help='Allow concurrent connections to each tunnel' - ) + parser.add_argument("-t", "--threaded", action="store_true", help="Allow concurrent connections to each tunnel") parser.add_argument( - '-v', '--verbose', - action='count', + "-v", + "--verbose", + action="count", default=0, - help='Increase output verbosity (default: {0})'.format( - logging.getLevelName(DEFAULT_LOGLEVEL) - ) + help="Increase output verbosity (default: {0})".format(logging.getLevelName(DEFAULT_LOGLEVEL)), ) parser.add_argument( - '-V', '--version', - action='version', - version='%(prog)s {version}'.format(version=__version__), - help='Show version number and quit' + "-V", + "--version", + action="version", + version="%(prog)s {version}".format(version=__version__), + help="Show version number and quit", ) parser.add_argument( - '-x', '--proxy', + "-x", + "--proxy", type=_bindlist, - dest='ssh_proxy', - metavar='IP:PORT', - help='IP and port of SSH proxy to destination' + dest="ssh_proxy", + metavar="IP:PORT", + help="IP and port of SSH proxy to destination", ) parser.add_argument( - '-c', '--config', + "-c", + "--config", type=str, default=SSH_CONFIG_FILE, - dest='ssh_config_file', - help='SSH configuration file, defaults to {0}'.format(SSH_CONFIG_FILE) + dest="ssh_config_file", + help="SSH configuration file, defaults to {0}".format(SSH_CONFIG_FILE), ) parser.add_argument( - '-z', '--compress', - action='store_true', - dest='compression', - help='Request server for compression over SSH transport' + "-z", + "--compress", + action="store_true", + dest="compression", + help="Request server for compression over SSH transport", ) parser.add_argument( - '-n', '--noagent', - action='store_false', - dest='allow_agent', - help='Disable looking for keys from an SSH agent' + "-n", "--noagent", action="store_false", dest="allow_agent", help="Disable looking for keys from an SSH agent" ) parser.add_argument( - '-d', '--host_pkey_directories', - nargs='*', - dest='host_pkey_directories', - metavar='FOLDER', - help='List of directories where SSH pkeys (in the format `id_*`) ' - 'may be found' + "-d", + "--host_pkey_directories", + nargs="*", + dest="host_pkey_directories", + metavar="FOLDER", + help="List of directories where SSH pkeys (in the format `id_*`) may be found", ) return vars(parser.parse_args(args)) def _cli_main(args=None, **extras): - """ Pass input arguments to open_tunnel - - Mandatory: ssh_address, -R (remote bind address list) - - Optional: - -U (username) we may gather it from SSH_CONFIG_FILE or current username - -p (server_port), defaults to 22 - -P (password) - -L (local_bind_address), default to 0.0.0.0:22 - -k (ssh_host_key) - -K (private_key_file), may be gathered from SSH_CONFIG_FILE - -S (private_key_password) - -t (threaded), allow concurrent connections over tunnels - -v (verbose), up to 3 (-vvv) to raise loglevel from ERROR to DEBUG - -V (version) - -x (proxy), ProxyCommand's IP:PORT, may be gathered from config file - -c (ssh_config), ssh configuration file (defaults to SSH_CONFIG_FILE) - -z (compress) - -n (noagent), disable looking for keys from an Agent - -d (host_pkey_directories), look for keys on these folders + """Pass input arguments to open_tunnel + + Mandatory: ssh_address, -R (remote bind address list) + + Optional: + -U (username) we may gather it from SSH_CONFIG_FILE or current username + -p (server_port), defaults to 22 + -P (password) + -L (local_bind_address), default to 0.0.0.0:22 + -k (ssh_host_key) + -K (private_key_file), may be gathered from SSH_CONFIG_FILE + -S (private_key_password) + -t (threaded), allow concurrent connections over tunnels + -v (verbose), up to 3 (-vvv) to raise loglevel from ERROR to DEBUG + -V (version) + -x (proxy), ProxyCommand's IP:PORT, may be gathered from config file + -c (ssh_config), ssh configuration file (defaults to SSH_CONFIG_FILE) + -z (compress) + -n (noagent), disable looking for keys from an Agent + -d (host_pkey_directories), look for keys on these folders """ arguments = _parse_arguments(args) # Remove all "None" input values _remove_none_values(arguments) - verbosity = min(arguments.pop('verbose'), 4) - levels = [logging.ERROR, - logging.WARNING, - logging.INFO, - logging.DEBUG, - TRACE_LEVEL] - arguments.setdefault('debug_level', levels[verbosity]) - # do this while supporting py27/py34 instead of merging dicts - for (extra, value) in extras.items(): + verbosity = min(arguments.pop("verbose"), 4) + levels = [logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG, TRACE_LEVEL] + arguments.setdefault("debug_level", levels[verbosity]) + for extra, value in extras.items(): arguments.setdefault(extra, value) with open_tunnel(**arguments) as tunnel: if tunnel.is_alive: - input_(''' + input_(""" Press or to stop! - ''') + """) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover _cli_main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..eb5bce87 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,74 @@ +from dataclasses import dataclass +from pathlib import Path +from textwrap import dedent + +import paramiko +import pytest + +PKEY_PASSWORD = "sshtunnel" + + +@dataclass(frozen=True) +class SSHKeyFixture: + """Ephemeral SSH key material generated for a test session. + + Paths are stored as str because sshtunnel's API uses isinstance(pkey, str) + checks internally. Use Path objects only for file I/O within the fixture. + """ + + dir: Path + plain_key_path: str + encrypted_key_path: str + config_path: str + rsa_key: paramiko.RSAKey + fingerprint: bytes + + @property + def fingerprints(self) -> dict[str, bytes]: + return {"ssh-rsa": self.fingerprint} + + +@pytest.fixture(scope="session") +def ssh_keys(tmp_path_factory: pytest.TempPathFactory) -> SSHKeyFixture: + """Generate ephemeral RSA keys, an encrypted variant, and an SSH config.""" + tmp_dir = tmp_path_factory.mktemp("sshtunnel-keys") + key = paramiko.RSAKey.generate(bits=2048) + + plain_path = tmp_dir / "testrsa.key" + key.write_private_key_file(str(plain_path)) + plain_path.chmod(0o600) + + encrypted_path = tmp_dir / "testrsa_encrypted.key" + key.write_private_key_file(str(encrypted_path), password=PKEY_PASSWORD) + encrypted_path.chmod(0o600) + + config_path = tmp_dir / "testconfig" + config_path.write_text( + dedent(f"""\ + Host * + User test + Compression yes + IdentityFile {plain_path} + Host test + ProxyCommand ssh -q -W %h:%p sshproxy + Host other + Port 222 + Hostname 10.0.0.1 + """) + ) + + return SSHKeyFixture( + dir=tmp_dir, + plain_key_path=str(plain_path), + encrypted_key_path=str(encrypted_path), + config_path=str(config_path), + rsa_key=key, + fingerprint=key.get_fingerprint(), + ) + + +@pytest.fixture(autouse=True) +def _inject_ssh_keys(request, ssh_keys): + """Make ssh_keys available to unittest.TestCase classes as cls.ssh_keys.""" + if request.cls is not None: + request.cls.ssh_keys = ssh_keys diff --git a/tests/requirements-syntax.txt b/tests/requirements-syntax.txt deleted file mode 100644 index 8e66e79c..00000000 --- a/tests/requirements-syntax.txt +++ /dev/null @@ -1,8 +0,0 @@ -bashtest -check-manifest -docutils -flake8 -mccabe -pygments -readme -twine diff --git a/tests/requirements.txt b/tests/requirements.txt index 6a91ea46..78b1a59c 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,13 +1,3 @@ -coveralls -mock pytest pytest-cov pytest-xdist -twine -# required by twine! -bleach<5.0.0 -# readme-renderer (required by twine) 25.0 has removed support for Python 3.4 -readme-renderer<25.0; python_version == '3.4' -# try to solve CI problem -importlib-metadata==1.7.0; python_version == '3.5' -importlib-metadata==1.1.3; python_version == '3.4' diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 40662d08..95a2b442 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -1,5 +1,3 @@ -from __future__ import with_statement - import os import sys import random @@ -8,66 +6,43 @@ import getpass import logging import argparse -import warnings import threading +import unittest +from io import StringIO from os import path, linesep from functools import partial from contextlib import contextmanager +from unittest import mock -import mock import paramiko import sshtunnel import shutil import tempfile -if sys.version_info[0] == 2: - from cStringIO import StringIO - if sys.version_info < (2, 7): - import unittest2 as unittest - else: - import unittest -else: - import unittest - from io import StringIO - +from .conftest import PKEY_PASSWORD -# UTILS def get_random_string(length=12): - """ - >>> r = get_random_string(1) - >>> r in asciis - True - >>> r = get_random_string(2) - >>> [r[0] in asciis, r[1] in asciis] - [True, True] - """ - ascii_lowercase = 'abcdefghijklmnopqrstuvwxyz' - ascii_uppercase = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' - digits = '0123456789' + ascii_lowercase = "abcdefghijklmnopqrstuvwxyz" + ascii_uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + digits = "0123456789" asciis = ascii_lowercase + ascii_uppercase + digits - return ''.join([random.choice(asciis) for _ in range(length)]) - - -def get_test_data_path(x): - return path.join(HERE, x) + return "".join([random.choice(asciis) for _ in range(length)]) @contextmanager def capture_stdout_stderr(): - (old_out, old_err) = (sys.stdout, sys.stderr) + old_out, old_err = sys.stdout, sys.stderr try: out = [StringIO(), StringIO()] - (sys.stdout, sys.stderr) = out + sys.stdout, sys.stderr = out yield out finally: - (sys.stdout, sys.stderr) = (old_out, old_err) + sys.stdout, sys.stderr = old_out, old_err out[0] = out[0].getvalue() out[1] = out[1].getvalue() -# Ensure that ``ssh_config_file is None`` during tests, exceptions are not -# raised and pkey loading from an SSH agent is disabled open_tunnel = partial( sshtunnel.open_tunnel, mute_exceptions=False, @@ -77,25 +52,12 @@ def capture_stdout_stderr(): host_pkey_directories=[], ) -# CONSTANTS - SSH_USERNAME = get_random_string() SSH_PASSWORD = get_random_string() -SSH_DSS = b'\x44\x78\xf0\xb9\xa2\x3c\xc5\x18\x20\x09\xff\x75\x5b\xc1\xd2\x6c' -SSH_RSA = b'\x60\x73\x38\x44\xcb\x51\x86\x65\x7f\xde\xda\xa2\x2b\x5a\x57\xd5' -ECDSA = b'\x25\x19\xeb\x55\xe6\xa1\x47\xff\x4f\x38\xd2\x75\x6f\xa5\xd5\x60' -FINGERPRINTS = { - 'ssh-dss': SSH_DSS, - 'ssh-rsa': SSH_RSA, - 'ecdsa-sha2-nistp256': ECDSA, -} DAEMON_THREADS = False HERE = path.abspath(path.dirname(__file__)) THREADS_TIMEOUT = 5.0 -PKEY_FILE = 'testrsa.key' -ENCRYPTED_PKEY_FILE = 'testrsa_encrypted.key' -TEST_CONFIG_FILE = 'testconfig' -TEST_UNIX_SOCKET = get_test_data_path('test_socket') +TEST_UNIX_SOCKET = path.join(HERE, "test_socket") sshtunnel.TRACE = True sshtunnel.SSH_TIMEOUT = 1.0 @@ -103,6 +65,7 @@ def capture_stdout_stderr(): # TESTS + class MockLoggingHandler(logging.Handler, object): """Mock logging handler to check for expected logs. @@ -111,8 +74,7 @@ class MockLoggingHandler(logging.Handler, object): """ def __init__(self, *args, **kwargs): - self.messages = {'debug': [], 'info': [], 'warning': [], 'error': [], - 'critical': [], 'trace': []} + self.messages = {"debug": [], "info": [], "warning": [], "error": [], "critical": [], "trace": []} super(MockLoggingHandler, self).__init__(*args, **kwargs) def emit(self, record): @@ -134,68 +96,61 @@ def reset(self): class NullServer(paramiko.ServerInterface): def __init__(self, *args, **kwargs): - # Allow tests to enable/disable specific key types - self.__allowed_keys = kwargs.pop('allowed_keys', []) - self.log = kwargs.pop('log', sshtunnel.create_logger(loglevel='DEBUG')) - super(NullServer, self).__init__(*args, **kwargs) + self.__allowed_keys = kwargs.pop("allowed_keys", []) + self.__fingerprints = kwargs.pop("fingerprints", {}) + self.log = kwargs.pop("log", sshtunnel.create_logger(loglevel="DEBUG")) + super().__init__(*args, **kwargs) def check_channel_forward_agent_request(self, channel): - self.log.debug('NullServer.check_channel_forward_agent_request() {0}' - .format(channel)) + self.log.debug("NullServer.check_channel_forward_agent_request() {0}".format(channel)) return False def get_allowed_auths(self, username): - allowed_auths = 'publickey{0}'.format( - ',password' if username == SSH_USERNAME else '' - ) - self.log.debug('NullServer >> allowed auths for {0}: {1}' - .format(username, allowed_auths)) + allowed_auths = "publickey{0}".format(",password" if username == SSH_USERNAME else "") + self.log.debug("NullServer >> allowed auths for {0}: {1}".format(username, allowed_auths)) return allowed_auths def check_auth_password(self, username, password): - _ok = (username == SSH_USERNAME and password == SSH_PASSWORD) - self.log.debug('NullServer >> password for {0} {1}OK' - .format(username, '' if _ok else 'NOT-')) + _ok = username == SSH_USERNAME and password == SSH_PASSWORD + self.log.debug("NullServer >> password for {0} {1}OK".format(username, "" if _ok else "NOT-")) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED def check_auth_publickey(self, username, key): try: - expected = FINGERPRINTS[key.get_name()] - _ok = (key.get_name() in self.__allowed_keys and - key.get_fingerprint() == expected) + expected = self.__fingerprints[key.get_name()] + _ok = key.get_name() in self.__allowed_keys and key.get_fingerprint() == expected except KeyError: _ok = False - self.log.debug('NullServer >> pkey authentication for {0} {1}OK' - .format(username, '' if _ok else 'NOT-')) + self.log.debug("NullServer >> pkey authentication for {0} {1}OK".format(username, "" if _ok else "NOT-")) return paramiko.AUTH_SUCCESSFUL if _ok else paramiko.AUTH_FAILED def check_channel_request(self, kind, chanid): - self.log.debug('NullServer.check_channel_request()') + self.log.debug("NullServer.check_channel_request()") return paramiko.OPEN_SUCCEEDED def check_channel_exec_request(self, channel, command): - self.log.debug('NullServer.check_channel_exec_request()') + self.log.debug("NullServer.check_channel_exec_request()") return True def check_port_forward_request(self, address, port): - self.log.debug('NullServer.check_port_forward_request()') + self.log.debug("NullServer.check_port_forward_request()") return True def check_global_request(self, kind, msg): - self.log.debug('NullServer.check_port_forward_request()') + self.log.debug("NullServer.check_port_forward_request()") return True def check_channel_direct_tcpip_request(self, chanid, origin, destination): - self.log.debug('NullServer.check_channel_direct_tcpip_request' - '(chanid={0}) {1} -> {2}' - .format(chanid, origin, destination)) + self.log.debug( + "NullServer.check_channel_direct_tcpip_request(chanid={0}) {1} -> {2}".format(chanid, origin, destination) + ) return paramiko.OPEN_SUCCEEDED class SSHClientTest(unittest.TestCase): def make_socket(self): s = socket.socket() - s.bind(('localhost', 0)) + s.bind(("localhost", 0)) s.listen(5) addr, port = s.getsockname() return s, addr, port @@ -205,27 +160,23 @@ def setUpClass(cls): super(SSHClientTest, cls).setUpClass() socket.setdefaulttimeout(sshtunnel.SSH_TIMEOUT) cls.log = logging.getLogger(sshtunnel.__name__) - cls.log = sshtunnel.create_logger(logger=cls.log, - loglevel='DEBUG') - cls._sshtunnel_log_handler = MockLoggingHandler(level='DEBUG') + cls.log = sshtunnel.create_logger(logger=cls.log, loglevel="DEBUG") + cls._sshtunnel_log_handler = MockLoggingHandler(level="DEBUG") cls.log.addHandler(cls._sshtunnel_log_handler) cls.sshtunnel_log_messages = cls._sshtunnel_log_handler.messages # set verbose format for logging - _fmt = '%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/' \ - '%(lineno)04d@%(module)-10.9s| %(message)s' + _fmt = "%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/%(lineno)04d@%(module)-10.9s| %(message)s" for handler in cls.log.handlers: handler.setFormatter(logging.Formatter(_fmt)) def setUp(self): super(SSHClientTest, self).setUp() - self.log.debug('*' * 80) - self.log.info('setUp for: {0}()'.format(self._testMethodName.upper())) + self.log.debug("*" * 80) + self.log.info("setUp for: {0}()".format(self._testMethodName.upper())) self.ssockl, self.saddr, self.sport = self.make_socket() self.esockl, self.eaddr, self.eport = self.make_socket() - self.log.info("Socket for ssh-server: {0}:{1}" - .format(self.saddr, self.sport)) - self.log.info("Socket for echo-server: {0}:{1}" - .format(self.eaddr, self.eport)) + self.log.info("Socket for ssh-server: {0}:{1}".format(self.saddr, self.sport)) + self.log.info("Socket for echo-server: {0}:{1}".format(self.eaddr, self.eport)) self.ssh_event = threading.Event() self.running_threads = [] @@ -235,47 +186,40 @@ def setUp(self): self._sshtunnel_log_handler.reset() def tearDown(self): - self.log.info('tearDown for: {0}()' - .format(self._testMethodName.upper())) + self.log.info("tearDown for: {0}()".format(self._testMethodName.upper())) self.stop_echo_and_ssh_server() for thread in self.running_threads: x = self.threads[thread] - self.log.info('thread {0} ({1})' - .format(thread, - 'alive' if x.is_alive() else 'defunct')) + self.log.info("thread {0} ({1})".format(thread, "alive" if x.is_alive() else "defunct")) while self.running_threads: for thread in self.running_threads: x = self.threads[thread] - self.wait_for_thread(self.threads[thread], - who='tearDown') + self.wait_for_thread(self.threads[thread], who="tearDown") if not x.is_alive(): - self.log.info('thread {0} now stopped'.format(thread)) + self.log.info("thread {0} now stopped".format(thread)) - for attr in ['server', 'tc', 'ts', 'socks', 'ssockl', 'esockl']: + for attr in ["server", "tc", "ts", "socks", "ssockl", "esockl"]: if hasattr(self, attr): - self.log.info('tearDown() {0}'.format(attr)) + self.log.info("tearDown() {0}".format(attr)) getattr(self, attr).close() def wait_for_thread(self, thread, timeout=THREADS_TIMEOUT, who=None): if thread.is_alive(): - self.log.debug('{0}waiting for {1} to end...' - .format('{0} '.format(who) if who else '', - thread.name)) + self.log.debug("{0}waiting for {1} to end...".format("{0} ".format(who) if who else "", thread.name)) thread.join(timeout) def start_echo_and_ssh_server(self): self.is_server_working = True self.start_echo_server() - t = threading.Thread(target=self._run_ssh_server, - name='ssh-server') + t = threading.Thread(target=self._run_ssh_server, name="ssh-server") t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t t.start() def stop_echo_and_ssh_server(self): - self.log.info('Sending STOP signal') + self.log.info("Sending STOP signal") self.is_server_working = False def _check_server_auth(self): @@ -283,8 +227,7 @@ def _check_server_auth(self): self.ssh_event.wait(sshtunnel.SSH_TIMEOUT) # wait for transport self.assertTrue(self.ssh_event.is_set()) self.assertTrue(self.ts.is_active()) - self.assertEqual(self.ts.get_username(), - SSH_USERNAME) + self.assertEqual(self.ts.get_username(), SSH_USERNAME) self.assertTrue(self.ts.is_authenticated()) @contextmanager @@ -297,128 +240,116 @@ def _test_server(self, *args, **kwargs): server._stop_transport() def start_echo_server(self): - t = threading.Thread(target=self._run_echo_server, - name='echo-server') + t = threading.Thread(target=self._run_echo_server, name="echo-server") t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t t.start() def _run_ssh_server(self): - self.log.info('ssh-server Start') + self.log.info("ssh-server Start") try: self.socks, addr = self.ssockl.accept() except socket.timeout: - self.log.error('ssh-server connection timed out!') - self.running_threads.remove('ssh-server') + self.log.error("ssh-server connection timed out!") + self.running_threads.remove("ssh-server") return self.ts = paramiko.Transport(self.socks) - host_key = paramiko.RSAKey.from_private_key_file( - get_test_data_path(PKEY_FILE) - ) + host_key = paramiko.RSAKey.from_private_key_file(self.ssh_keys.plain_key_path) self.ts.add_server_key(host_key) - server = NullServer(allowed_keys=FINGERPRINTS.keys(), - log=self.log) - t = threading.Thread(target=self._do_forwarding, - name='forward-server') + server = NullServer( + allowed_keys=self.ssh_keys.fingerprints.keys(), fingerprints=self.ssh_keys.fingerprints, log=self.log + ) + t = threading.Thread(target=self._do_forwarding, name="forward-server") t.daemon = DAEMON_THREADS self.running_threads.append(t.name) self.threads[t.name] = t t.start() self.ts.start_server(self.ssh_event, server) - self.wait_for_thread(t, - timeout=None, - who='ssh-server') - self.log.info('ssh-server shutting down') - self.running_threads.remove('ssh-server') + self.wait_for_thread(t, timeout=None, who="ssh-server") + self.log.info("ssh-server shutting down") + self.running_threads.remove("ssh-server") def _run_echo_server(self, timeout=sshtunnel.SSH_TIMEOUT): - self.log.info('echo-server Started') + self.log.info("echo-server Started") self.ssh_event.wait(timeout) # wait for transport socks = [self.esockl] try: while self.is_server_working: - inputready, _, _ = select.select(socks, - [], - [], - timeout) + inputready, _, _ = select.select(socks, [], [], timeout) for s in inputready: if s == self.esockl: # handle the server socket try: client, address = self.esockl.accept() - self.log.info('echo-server accept() {0}' - .format(address)) + self.log.info("echo-server accept() {0}".format(address)) except OSError: - self.log.info('echo-server accept() OSError') + self.log.info("echo-server accept() OSError") break socks.append(client) else: # handle all other sockets try: data = s.recv(1000) - self.log.info('echo-server echoing {0}' - .format(data)) + self.log.info("echo-server echoing {0}".format(data)) s.send(data) except OSError: - self.log.warning('echo-server OSError') + self.log.warning("echo-server OSError") continue finally: s.close() socks.remove(s) - self.log.info('<<< echo-server received STOP signal') + self.log.info("<<< echo-server received STOP signal") except Exception as e: - self.log.info('echo-server got Exception: {0}'.format(repr(e))) + self.log.info("echo-server got Exception: {0}".format(repr(e))) finally: self.is_server_working = False - if 'forward-server' in self.threads: - t = self.threads['forward-server'] - self.wait_for_thread(t, timeout=None, who='echo-server') - self.running_threads.remove('forward-server') + if "forward-server" in self.threads: + t = self.threads["forward-server"] + self.wait_for_thread(t, timeout=None, who="echo-server") + self.running_threads.remove("forward-server") for s in socks: s.close() - self.log.info('echo-server shutting down') - self.running_threads.remove('echo-server') + self.log.info("echo-server shutting down") + self.running_threads.remove("echo-server") def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): - self.log.debug('forward-server Start') + self.log.debug("forward-server Start") self.ssh_event.wait(THREADS_TIMEOUT) # wait for SSH server's transport + schan = None + echo = None + info = "forward-server schan <> echo" try: schan = self.ts.accept(timeout=timeout) - info = "forward-server schan <> echo" + if schan is None: + self.log.info("forward-server accept() returned None") + return self.log.info(info + " accept()") - echo = socket.create_connection( - (self.eaddr, self.eport) - ) + echo = socket.create_connection((self.eaddr, self.eport)) while self.is_server_working: - rqst, _, _ = select.select([schan, echo], - [], - [], - timeout) + rqst, _, _ = select.select([schan, echo], [], [], timeout) if schan in rqst: data = schan.recv(1024) - self.log.debug('{0} -->: {1}'.format(info, repr(data))) + self.log.debug("{0} -->: {1}".format(info, repr(data))) echo.send(data) if len(data) == 0: break if echo in rqst: data = echo.recv(1024) - self.log.debug('{0} <--: {1}'.format(info, repr(data))) + self.log.debug("{0} <--: {1}".format(info, repr(data))) schan.send(data) if len(data) == 0: break - self.log.info('<<< forward-server received STOP signal') + self.log.info("<<< forward-server received STOP signal") except socket.error: - self.log.critical('{0} sending RST'.format(info)) - # except Exception as e: - # # we reach this point usually when schan is None (paramiko bug?) - # self.log.critical(repr(e)) + self.log.critical("{0} sending RST".format(info)) finally: if schan: - self.log.debug('{0} closing connection...'.format(info)) + self.log.debug("{0} closing connection...".format(info)) schan.close() + if echo: echo.close() - self.log.debug('{0} connection closed.'.format(info)) + self.log.debug("{0} connection closed.".format(info)) def randomize_eport(self): return random.randint(49152, 65535) @@ -432,19 +363,18 @@ def test_echo_server(self): logger=self.log, ) as server: message = get_random_string().encode() - local_bind_addr = ('127.0.0.1', server.local_bind_port) - self.log.info('_test_server(): try connect!') + local_bind_addr = ("127.0.0.1", server.local_bind_port) + self.log.info("_test_server(): try connect!") s = socket.create_connection(local_bind_addr) - self.log.info('_test_server(): connected from {0}! try send!' - .format(s.getsockname())) + self.log.info("_test_server(): connected from {0}! try send!".format(s.getsockname())) s.send(message) - self.log.info('_test_server(): sent!') - z = (s.recv(1000)) + self.log.info("_test_server(): sent!") + z = s.recv(1000) self.assertEqual(z, message) s.close() def test_connect_by_username_password(self): - """ Test connecting using username/password as authentication """ + """Test connecting using username/password as authentication""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -455,21 +385,19 @@ def test_connect_by_username_password(self): pass # no exceptions are raised def test_connect_by_rsa_key_file(self): - """ Test connecting using a RSA key file """ + """Test connecting using a RSA key file""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, - ssh_pkey=get_test_data_path(PKEY_FILE), + ssh_pkey=self.ssh_keys.plain_key_path, remote_bind_address=(self.eaddr, self.eport), logger=self.log, ): pass # no exceptions are raised def test_connect_by_paramiko_key(self): - """ Test connecting when ssh_private_key is a paramiko.RSAKey """ - ssh_key = paramiko.RSAKey.from_private_key_file( - get_test_data_path(PKEY_FILE) - ) + """Test connecting when ssh_private_key is a paramiko.RSAKey""" + ssh_key = paramiko.RSAKey.from_private_key_file(self.ssh_keys.plain_key_path) with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -480,7 +408,7 @@ def test_connect_by_paramiko_key(self): pass def test_open_tunnel(self): - """ Test wrapper method mainly used from CLI """ + """Test wrapper method mainly used from CLI""" server = sshtunnel.open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -506,28 +434,34 @@ def test_sshaddress_and_sshaddressorhost_mutually_exclusive(self): Test that deprecate argument ssh_address cannot be used together with ssh_address_or_host """ - with self.assertRaises(ValueError): - open_tunnel( - ssh_address_or_host=(self.saddr, self.sport), - ssh_address=(self.saddr, self.sport), - ssh_username=SSH_USERNAME, - ssh_password=SSH_PASSWORD, - remote_bind_address=(self.eaddr, self.eport), - ) + import pytest + + with pytest.warns(DeprecationWarning, match="ssh_address"): + with self.assertRaises(ValueError): + open_tunnel( + ssh_address_or_host=(self.saddr, self.sport), + ssh_address=(self.saddr, self.sport), + ssh_username=SSH_USERNAME, + ssh_password=SSH_PASSWORD, + remote_bind_address=(self.eaddr, self.eport), + ) def test_sshhost_and_sshaddressorhost_mutually_exclusive(self): """ Test that deprecate argument ssh_host cannot be used together with ssh_address_or_host """ - with self.assertRaises(ValueError): - open_tunnel( - ssh_address_or_host=(self.saddr, self.sport), - ssh_host=(self.saddr, self.sport), - ssh_username=SSH_USERNAME, - ssh_password=SSH_PASSWORD, - remote_bind_address=(self.eaddr, self.eport), - ) + import pytest + + with pytest.warns(DeprecationWarning, match="ssh_host"): + with self.assertRaises(ValueError): + open_tunnel( + ssh_address_or_host=(self.saddr, self.sport), + ssh_host=(self.saddr, self.sport), + ssh_username=SSH_USERNAME, + ssh_password=SSH_PASSWORD, + remote_bind_address=(self.eaddr, self.eport), + ) def test_sshaddressorhost_may_not_be_a_tuple(self): """ @@ -550,7 +484,7 @@ def test_unknown_argument_raises_exception(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - i_do_not_exist=0 + i_do_not_exist=0, ) def test_more_local_than_remote_bind_sizes_raises_exception(self): @@ -564,8 +498,7 @@ def test_more_local_than_remote_bind_sizes_raises_exception(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_addresses=[('127.0.0.1', self.eport), - ('127.0.0.1', self.randomize_eport())] + local_bind_addresses=[("127.0.0.1", self.eport), ("127.0.0.1", self.randomize_eport())], ) def test_localbindaddress_and_localbindaddresses_mutually_exclusive(self): @@ -579,9 +512,8 @@ def test_localbindaddress_and_localbindaddresses_mutually_exclusive(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('127.0.0.1', self.eport), - local_bind_addresses=[('127.0.0.1', self.eport), - ('127.0.0.1', self.randomize_eport())] + local_bind_address=("127.0.0.1", self.eport), + local_bind_addresses=[("127.0.0.1", self.eport), ("127.0.0.1", self.randomize_eport())], ) def test_localbindaddress_host_is_optional(self): @@ -594,9 +526,9 @@ def test_localbindaddress_host_is_optional(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('', self.randomize_eport()) + local_bind_address=("", self.randomize_eport()), ) as server: - self.assertEqual(server.local_bind_host, '0.0.0.0') + self.assertEqual(server.local_bind_host, "0.0.0.0") def test_localbindaddress_port_is_optional(self): """ @@ -608,7 +540,7 @@ def test_localbindaddress_port_is_optional(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('127.0.0.1', ) + local_bind_address=("127.0.0.1",), ) as server: self.assertIsInstance(server.local_bind_port, int) @@ -623,8 +555,7 @@ def test_remotebindaddress_and_remotebindaddresses_are_exclusive(self): ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - remote_bind_addresses=[(self.eaddr, self.eport), - (self.eaddr, self.randomize_eport())] + remote_bind_addresses=[(self.eaddr, self.eport), (self.eaddr, self.randomize_eport())], ) def test_no_remote_bind_address_raises_exception(self): @@ -638,28 +569,24 @@ def test_no_remote_bind_address_raises_exception(self): ssh_username=SSH_USERNAME, ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_reading_from_a_bad_sshconfigfile_does_not_raise_error(self): """ Test that when a bad ssh_config file is found, a warning is shown but no exception is raised """ - ssh_config_file = 'not_existing_file' + ssh_config_file = "not_existing_file" open_tunnel( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('127.0.0.1', self.randomize_eport()), + local_bind_address=("127.0.0.1", self.randomize_eport()), logger=self.log, - ssh_config_file=ssh_config_file + ssh_config_file=ssh_config_file, ) - logged_message = 'Could not read SSH configuration file: {0}'.format( - ssh_config_file - ) - self.assertIn(logged_message, self.sshtunnel_log_messages['warning']) + logged_message = "Could not read SSH configuration file: {0}".format(ssh_config_file) + self.assertIn(logged_message, self.sshtunnel_log_messages["warning"]) def test_not_setting_password_or_pkey_raises_error(self): """ @@ -671,51 +598,42 @@ def test_not_setting_password_or_pkey_raises_error(self): (self.saddr, self.sport), ssh_username=SSH_USERNAME, remote_bind_address=(self.eaddr, self.eport), - ssh_config_file=None + ssh_config_file=None, ) - @unittest.skipIf(os.name == 'nt', - reason='Need to fix test on Windows') + @unittest.skipIf(os.name == "nt", reason="Need to fix test on Windows") def test_deprecate_warnings_are_shown(self): """Test that when using deprecate arguments a warning is logged""" - warnings.simplefilter('always') # don't ignore DeprecationWarnings - - with warnings.catch_warnings(record=True) as w: - for deprecated_arg in ['ssh_address', 'ssh_host']: - _kwargs = { - deprecated_arg: (self.saddr, self.sport), - 'ssh_username': SSH_USERNAME, - 'ssh_password': SSH_PASSWORD, - 'remote_bind_address': (self.eaddr, self.eport), - } - open_tunnel(**_kwargs) - logged_message = "'{0}' is DEPRECATED use '{1}' instead"\ - .format(deprecated_arg, - sshtunnel._DEPRECATIONS[deprecated_arg]) - self.assertTrue(issubclass(w[-1].category, DeprecationWarning)) - self.assertEqual(logged_message, str(w[-1].message)) - - # other deprecated arguments - with warnings.catch_warnings(record=True) as w: - for deprecated_arg in [ - 'raise_exception_if_any_forwarder_have_a_problem', - 'ssh_private_key' - ]: - _kwargs = { - 'ssh_address_or_host': (self.saddr, self.sport), - 'ssh_username': SSH_USERNAME, - 'ssh_password': SSH_PASSWORD, - 'remote_bind_address': (self.eaddr, self.eport), - deprecated_arg: (self.saddr, self.sport), - } - open_tunnel(**_kwargs) - logged_message = "'{0}' is DEPRECATED use '{1}' instead"\ - .format(deprecated_arg, - sshtunnel._DEPRECATIONS[deprecated_arg]) - self.assertTrue(issubclass(w[-1].category, DeprecationWarning)) - self.assertEqual(logged_message, str(w[-1].message)) - - warnings.simplefilter('default') + import pytest + + for deprecated_arg in ["ssh_address", "ssh_host"]: + expected = "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_arg, sshtunnel._DEPRECATIONS[deprecated_arg] + ) + with pytest.warns(DeprecationWarning, match=expected): + open_tunnel( + **{ + deprecated_arg: (self.saddr, self.sport), + "ssh_username": SSH_USERNAME, + "ssh_password": SSH_PASSWORD, + "remote_bind_address": (self.eaddr, self.eport), + } + ) + + for deprecated_arg in ["raise_exception_if_any_forwarder_have_a_problem", "ssh_private_key"]: + expected = "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_arg, sshtunnel._DEPRECATIONS[deprecated_arg] + ) + with pytest.warns(DeprecationWarning, match=expected): + open_tunnel( + **{ + "ssh_address_or_host": (self.saddr, self.sport), + "ssh_username": SSH_USERNAME, + "ssh_password": SSH_PASSWORD, + "remote_bind_address": (self.eaddr, self.eport), + deprecated_arg: (self.saddr, self.sport), + } + ) def test_gateway_unreachable_raises_exception(self): """ @@ -732,8 +650,6 @@ def test_gateway_unreachable_raises_exception(self): ): pass - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_gateway_ip_unresolvable_raises_exception(self): """ BaseSSHTunnelForwarderError is raised when not able to resolve the @@ -749,50 +665,38 @@ def test_gateway_ip_unresolvable_raises_exception(self): ): pass self.assertIn( - 'Could not resolve IP address for {0}, aborting!'.format( - SSH_USERNAME - ), - self.sshtunnel_log_messages['error'] + "Could not resolve IP address for {0}, aborting!".format(SSH_USERNAME), self.sshtunnel_log_messages["error"] ) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_running_start_twice_logs_warning(self): """Test that when running start() twice a warning is shown""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_address=(self.eaddr, self.eport) + remote_bind_address=(self.eaddr, self.eport), ) as server: - self.assertNotIn('Already started!', - self.sshtunnel_log_messages['warning']) + self.assertNotIn("Already started!", self.sshtunnel_log_messages["warning"]) server.logger.error(server.is_active) server.logger.error(server.is_alive) server.start() # 2nd start should prompt the warning - self.assertIn('Already started!', - self.sshtunnel_log_messages['warning']) + self.assertIn("Already started!", self.sshtunnel_log_messages["warning"]) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_stop_before_start_logs_warning(self): """ Test that running .stop() on an already stopped server logs a warning """ server = open_tunnel( - '10.10.10.10', + "10.10.10.10", ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_address=('10.0.0.1', 8080), + remote_bind_address=("10.0.0.1", 8080), mute_exceptions=True, logger=self.log, ) server.stop() - self.assertIn('Server is not started. Please .start() first!', - self.sshtunnel_log_messages['warning']) + self.assertIn("Server is not started. Please .start() first!", self.sshtunnel_log_messages["warning"]) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_wrong_auth_to_gateway_logs_error(self): """ Test that when connecting to the ssh gateway with wrong credentials, @@ -807,16 +711,13 @@ def test_wrong_auth_to_gateway_logs_error(self): logger=self.log, ): pass - self.assertIn('Could not open connection to gateway', - self.sshtunnel_log_messages['error']) + self.assertIn("Could not open connection to gateway", self.sshtunnel_log_messages["error"]) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_missing_pkey_file_logs_warning(self): """ Test that when the private key file is missing, a warning is logged """ - bad_pkey = 'this_file_does_not_exist' + bad_pkey = "this_file_does_not_exist" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -825,13 +726,11 @@ def test_missing_pkey_file_logs_warning(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ): - self.assertIn('Private key file not found: {0}'.format(bad_pkey), - self.sshtunnel_log_messages['warning']) + self.assertIn("Private key file not found: {0}".format(bad_pkey), self.sshtunnel_log_messages["warning"]) def test_connect_via_proxy(self): - """ Test connecting using a ProxyCommand """ - proxycmd = paramiko.proxy.ProxyCommand('ssh proxy -W {0}:{1}' - .format(self.saddr, self.sport)) + """Test connecting using a ProxyCommand""" + proxycmd = paramiko.proxy.ProxyCommand("ssh proxy -W {0}:{1}".format(self.saddr, self.sport)) server = open_tunnel( self.saddr, ssh_username=SSH_USERNAME, @@ -841,12 +740,10 @@ def test_connect_via_proxy(self): ssh_proxy_enabled=True, logger=self.log, ) - self.assertEqual(server.ssh_proxy.cmd[1], 'proxy') + self.assertEqual(server.ssh_proxy.cmd[1], "proxy") - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_can_skip_loading_sshconfig(self): - """ Test that we can skip loading the ~/.ssh/config file """ + """Test that we can skip loading the ~/.ssh/config file""" server = open_tunnel( (self.saddr, self.sport), ssh_password=SSH_PASSWORD, @@ -855,13 +752,12 @@ def test_can_skip_loading_sshconfig(self): logger=self.log, ) self.assertEqual(server.ssh_username, getpass.getuser()) - self.assertIn('Skipping loading of ssh configuration file', - self.sshtunnel_log_messages['info']) + self.assertIn("Skipping loading of ssh configuration file", self.sshtunnel_log_messages["info"]) def test_local_bind_port(self): - """ Test local_bind_port property """ + """Test local_bind_port property""" s = socket.socket() - s.bind(('localhost', 0)) + s.bind(("localhost", 0)) addr, port = s.getsockname() s.close() with self._test_server( @@ -876,7 +772,7 @@ def test_local_bind_port(self): self.assertEqual(server.local_bind_port, port) def test_local_bind_host(self): - """ Test local_bind_host property """ + """Test local_bind_host property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -889,9 +785,9 @@ def test_local_bind_host(self): self.assertEqual(server.local_bind_host, self.saddr) def test_local_bind_address(self): - """ Test local_bind_address property """ + """Test local_bind_address property""" s = socket.socket() - s.bind(('localhost', 0)) + s.bind(("localhost", 0)) addr, port = s.getsockname() s.close() with self._test_server( @@ -906,13 +802,12 @@ def test_local_bind_address(self): self.assertTupleEqual(server.local_bind_address, (addr, port)) def test_local_bind_ports(self): - """ Test local_bind_ports property """ + """Test local_bind_ports property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[(self.eaddr, self.eport), (self.saddr, self.sport)], logger=self.log, ) as server: self.assertIsInstance(server.local_bind_ports, list) @@ -930,44 +825,37 @@ def test_local_bind_ports(self): self.assertIsInstance(server.local_bind_ports, list) def test_local_bind_hosts(self): - """ Test local_bind_hosts property """ + """Test local_bind_hosts property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, local_bind_addresses=[(self.saddr, 0)] * 2, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[(self.eaddr, self.eport), (self.saddr, self.sport)], logger=self.log, ) as server: self.assertIsInstance(server.local_bind_hosts, list) - self.assertListEqual(server.local_bind_hosts, - [self.saddr] * 2) + self.assertListEqual(server.local_bind_hosts, [self.saddr] * 2) with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_host) def test_local_bind_addresses(self): - """ Test local_bind_addresses property """ + """Test local_bind_addresses property""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, local_bind_addresses=[(self.saddr, 0)] * 2, - remote_bind_addresses=[(self.eaddr, self.eport), - (self.saddr, self.sport)], + remote_bind_addresses=[(self.eaddr, self.eport), (self.saddr, self.sport)], logger=self.log, ) as server: self.assertIsInstance(server.local_bind_addresses, list) - self.assertListEqual(server.local_bind_addresses, - list(zip([self.saddr] * 2, - server.local_bind_ports))) + self.assertListEqual(server.local_bind_addresses, list(zip([self.saddr] * 2, server.local_bind_ports))) with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): self.log.info(server.local_bind_address) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_check_tunnels(self): - """ Test method checking if tunnels are up """ + """Test method checking if tunnels are up""" remote_address = (self.eaddr, self.eport) with self._test_server( (self.saddr, self.sport), @@ -977,77 +865,74 @@ def test_check_tunnels(self): logger=self.log, skip_tunnel_checkup=False, ) as server: - self.assertIn('Tunnel to {0} is UP'.format(remote_address), - self.sshtunnel_log_messages['debug']) + self.assertIn("Tunnel to {0} is UP".format(remote_address), self.sshtunnel_log_messages["debug"]) server.check_tunnels() - self.assertIn('Tunnel to {0} is DOWN'.format(remote_address), - self.sshtunnel_log_messages['debug']) + self.assertIn("Tunnel to {0} is DOWN".format(remote_address), self.sshtunnel_log_messages["debug"]) # Calling local_is_up() should also return the same server.skip_tunnel_checkup = True server.local_is_up((self.saddr, self.sport)) - self.assertIn('Tunnel to {0} is DOWN'.format(remote_address), - self.sshtunnel_log_messages['debug']) + self.assertIn("Tunnel to {0} is DOWN".format(remote_address), self.sshtunnel_log_messages["debug"]) self.assertFalse(server.local_is_up("not a valid address")) - self.assertIn('Target must be a tuple (IP, port), where IP ' - 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000). Alternatively ' - 'target can be a valid UNIX domain socket.', - self.sshtunnel_log_messages['warning']) + self.assertIn( + "Target must be a tuple (IP, port), where IP " + 'is a string (i.e. "192.168.0.1") and port is ' + "an integer (i.e. 40000). Alternatively " + "target can be a valid UNIX domain socket.", + self.sshtunnel_log_messages["warning"], + ) - @mock.patch('sshtunnel.input_', return_value=linesep) + @mock.patch("sshtunnel.input_", return_value=linesep) def test_cli_main_exits_when_pressing_enter(self, input): - """ Test that _cli_main() function quits when Enter is pressed """ + """Test that _cli_main() function quits when Enter is pressed""" + import pytest + self.start_echo_and_ssh_server() - sshtunnel._cli_main(args=[self.saddr, - '-U', SSH_USERNAME, - '-P', SSH_PASSWORD, - '-p', str(self.sport), - '-R', '{0}:{1}'.format(self.eaddr, - self.eport), - '-c', '', - '-n'], - host_pkey_directories=[]) + with pytest.warns(DeprecationWarning, match="ssh_address"): + sshtunnel._cli_main( + args=[ + self.saddr, + "-U", + SSH_USERNAME, + "-P", + SSH_PASSWORD, + "-p", + str(self.sport), + "-R", + "{0}:{1}".format(self.eaddr, self.eport), + "-c", + "", + "-n", + ], + host_pkey_directories=[], + ) self.stop_echo_and_ssh_server() - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_read_private_key_file(self): - """ Test that an encrypted private key can be opened """ - encr_pkey = get_test_data_path(ENCRYPTED_PKEY_FILE) + """Test that an encrypted private key can be opened""" + encr_pkey = self.ssh_keys.encrypted_key_path pkey = sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - pkey_password='sshtunnel', - logger=self.log - ) - _pkey = paramiko.RSAKey.from_private_key_file( - get_test_data_path(PKEY_FILE) + encr_pkey, pkey_password=PKEY_PASSWORD, logger=self.log ) + _pkey = paramiko.RSAKey.from_private_key_file(self.ssh_keys.plain_key_path) self.assertEqual(pkey, _pkey) # Using a wrong password returns None - self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - pkey_password='bad password', - logger=self.log - )) - self.assertIn("Private key file ({0}) could not be loaded as type " - "{1} or bad password" - .format(encr_pkey, type(_pkey)), - self.sshtunnel_log_messages['debug']) + self.assertIsNone( + sshtunnel.SSHTunnelForwarder.read_private_key_file(encr_pkey, pkey_password="bad password", logger=self.log) + ) + self.assertIn( + "Private key file ({0}) could not be loaded as type {1} or bad password".format(encr_pkey, type(_pkey)), + self.sshtunnel_log_messages["debug"], + ) # Using no password on an encrypted key returns None - self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file( - encr_pkey, - logger=self.log - )) - self.assertIn('Password is required for key {0}'.format(encr_pkey), - self.sshtunnel_log_messages['error']) - - @unittest.skipIf(os.name != 'posix', - reason="UNIX sockets not supported on this platform") + self.assertIsNone(sshtunnel.SSHTunnelForwarder.read_private_key_file(encr_pkey, logger=self.log)) + self.assertIn("Password is required for key {0}".format(encr_pkey), self.sshtunnel_log_messages["error"]) + + @unittest.skipIf(os.name != "posix", reason="UNIX sockets not supported on this platform") def test_unix_domains(self): - """ Test use of UNIX domain sockets in local binds """ + """Test use of UNIX domain sockets in local binds""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -1058,14 +943,11 @@ def test_unix_domains(self): ) as server: self.assertEqual(server.local_bind_address, TEST_UNIX_SOCKET) - @unittest.skipIf(sys.version_info < (2, 7), - reason="Cannot intercept logging messages in py26") def test_tracing_logging(self): """ Test that Tracing mode may be enabled for more fine-grained logs """ - logger = sshtunnel.create_logger(logger=self.log, - loglevel='TRACE') + logger = sshtunnel.create_logger(logger=self.log, loglevel="TRACE") with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, @@ -1073,21 +955,18 @@ def test_tracing_logging(self): remote_bind_address=(self.eaddr, self.eport), logger=logger, ) as server: - server.logger = sshtunnel.create_logger(logger=server.logger, - loglevel='TRACE') + server.logger = sshtunnel.create_logger(logger=server.logger, loglevel="TRACE") message = get_random_string(100).encode() # Windows raises WinError 10049 if trying to connect to 0.0.0.0 - s = socket.create_connection(('127.0.0.1', server.local_bind_port)) + s = socket.create_connection(("127.0.0.1", server.local_bind_port)) s.send(message) s.recv(100) s.close - log = 'send to {0}'.format((self.eaddr, self.eport)) + log = "send to {0}".format((self.eaddr, self.eport)) - self.assertTrue(any(log in msg for msg in - self.sshtunnel_log_messages['trace'])) + self.assertTrue(any(log in msg for msg in self.sshtunnel_log_messages["trace"])) # set loglevel back to the original value - logger = sshtunnel.create_logger(logger=self.log, - loglevel='DEBUG') + logger = sshtunnel.create_logger(logger=self.log, loglevel="DEBUG") def test_tunnel_bindings_contain_active_tunnels(self): """ @@ -1099,21 +978,13 @@ def test_tunnel_bindings_contain_active_tunnels(self): (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_addresses=[(self.eaddr, remote_ports[0]), - (self.eaddr, remote_ports[1])], - local_bind_addresses=[('127.0.0.1', local_ports[0]), - ('127.0.0.1', local_ports[1])], + remote_bind_addresses=[(self.eaddr, remote_ports[0]), (self.eaddr, remote_ports[1])], + local_bind_addresses=[("127.0.0.1", local_ports[0]), ("127.0.0.1", local_ports[1])], skip_tunnel_checkup=False, ) as server: self.assertListEqual(server.local_bind_ports, local_ports) - self.assertTupleEqual( - server.tunnel_bindings[(self.eaddr, remote_ports[0])], - ('127.0.0.1', local_ports[0]) - ) - self.assertTupleEqual( - server.tunnel_bindings[(self.eaddr, remote_ports[1])], - ('127.0.0.1', local_ports[1]) - ) + self.assertTupleEqual(server.tunnel_bindings[(self.eaddr, remote_ports[0])], ("127.0.0.1", local_ports[0])) + self.assertTupleEqual(server.tunnel_bindings[(self.eaddr, remote_ports[1])], ("127.0.0.1", local_ports[1])) def check_make_ssh_forward_server_sets_daemon(self, case): self.start_echo_and_ssh_server() @@ -1148,70 +1019,71 @@ def test_make_ssh_forward_server_sets_daemon_false(self): self.check_make_ssh_forward_server_sets_daemon(False) def test_get_keys(self): - """ Test loading keys from the paramiko Agent """ + """Test loading keys from the paramiko Agent""" with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('', self.randomize_eport()), - logger=self.log + local_bind_address=("", self.randomize_eport()), + logger=self.log, ) as server: keys = server.get_keys(logger=self.log) self.assertIsInstance(keys, list) - self.assertFalse(any('keys loaded from agent' in msg for msg in - self.sshtunnel_log_messages['info'])) + self.assertFalse(any("keys loaded from agent" in msg for msg in self.sshtunnel_log_messages["info"])) with self._test_server( (self.saddr, self.sport), ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, remote_bind_address=(self.eaddr, self.eport), - local_bind_address=('', self.randomize_eport()), - logger=self.log + local_bind_address=("", self.randomize_eport()), + logger=self.log, ) as server: keys = server.get_keys(logger=self.log, allow_agent=True) self.assertIsInstance(keys, list) - self.assertTrue(any('keys loaded from agent' in msg for msg in - self.sshtunnel_log_messages['info'])) + self.assertTrue(any("keys loaded from agent" in msg for msg in self.sshtunnel_log_messages["info"])) tmp_dir = tempfile.mkdtemp() - shutil.copy(get_test_data_path(PKEY_FILE), - os.path.join(tmp_dir, 'id_rsa')) + shutil.copy(self.ssh_keys.plain_key_path, os.path.join(tmp_dir, "id_rsa")) keys = sshtunnel.SSHTunnelForwarder.get_keys( self.log, - host_pkey_directories=[tmp_dir, ] + host_pkey_directories=[ + tmp_dir, + ], ) self.assertIsInstance(keys, list) - self.assertTrue( - any('1 key(s) loaded' in msg - for msg in self.sshtunnel_log_messages['info']) - ) + self.assertTrue(any("1 key(s) loaded" in msg for msg in self.sshtunnel_log_messages["info"])) shutil.rmtree(tmp_dir) class AuxiliaryTest(unittest.TestCase): - """ Set of tests that do not need the mock SSH server or logger """ + """Set of tests that do not need the mock SSH server or logger""" def test_parse_arguments_short(self): - """ Test CLI argument parsing with short parameter names """ - args = ['10.10.10.10', # ssh_address - '-U={0}'.format(getpass.getuser()), # GW username - '-p=22', # GW SSH port - '-P={0}'.format(SSH_PASSWORD), # GW password - '-R', '10.0.0.1:8080', '10.0.0.2:8080', # remote bind list - '-L', ':8081', ':8082', # local bind list - '-k={0}'.format(SSH_DSS), # hostkey - '-K={0}'.format(__file__), # pkey file - '-S={0}'.format(SSH_PASSWORD), # pkey password - '-t', # concurrent connections (threaded) - '-vvv', # triple verbosity - '-x=10.0.0.2:', # proxy address - '-c=ssh_config', # ssh configuration file - '-z', # request compression - '-n', # disable SSH agent key lookup - ] + """Test CLI argument parsing with short parameter names""" + args = [ + "10.10.10.10", # ssh_address + "-U={0}".format(getpass.getuser()), # GW username + "-p=22", # GW SSH port + "-P={0}".format(SSH_PASSWORD), # GW password + "-R", + "10.0.0.1:8080", + "10.0.0.2:8080", # remote bind list + "-L", + ":8081", + ":8082", # local bind list + "-k={0}".format(self.ssh_keys.fingerprint), # hostkey + "-K={0}".format(__file__), # pkey file + "-S={0}".format(SSH_PASSWORD), # pkey password + "-t", # concurrent connections (threaded) + "-vvv", # triple verbosity + "-x=10.0.0.2:", # proxy address + "-c=ssh_config", # ssh configuration file + "-z", # request compression + "-n", # disable SSH agent key lookup + ] parser = sshtunnel._parse_arguments(args) self._test_parser(parser) @@ -1224,179 +1096,166 @@ def test_parse_arguments_short(self): parser = sshtunnel._parse_arguments(args[:4] + args[5:]) def test_parse_arguments_long(self): - """ Test CLI argument parsing with long parameter names """ + """Test CLI argument parsing with long parameter names""" parser = sshtunnel._parse_arguments( - ['10.10.10.10', # ssh_address - '--username={0}'.format(getpass.getuser()), # GW username - '--server_port=22', # GW SSH port - '--password={0}'.format(SSH_PASSWORD), # GW password - '--remote_bind_address', '10.0.0.1:8080', '10.0.0.2:8080', - '--local_bind_address', ':8081', ':8082', # local bind list - '--ssh_host_key={0}'.format(SSH_DSS), # hostkey - '--private_key_file={0}'.format(__file__), # pkey file - '--private_key_password={0}'.format(SSH_PASSWORD), - '--threaded', # concurrent connections (threaded) - '--verbose', '--verbose', '--verbose', # triple verbosity - '--proxy', '10.0.0.2:22', # proxy address - '--config', 'ssh_config', # ssh configuration file - '--compress', # request compression - '--noagent', # disable SSH agent key lookup - ] + [ + "10.10.10.10", # ssh_address + "--username={0}".format(getpass.getuser()), # GW username + "--server_port=22", # GW SSH port + "--password={0}".format(SSH_PASSWORD), # GW password + "--remote_bind_address", + "10.0.0.1:8080", + "10.0.0.2:8080", + "--local_bind_address", + ":8081", + ":8082", # local bind list + "--ssh_host_key={0}".format(self.ssh_keys.fingerprint), # hostkey + "--private_key_file={0}".format(__file__), # pkey file + "--private_key_password={0}".format(SSH_PASSWORD), + "--threaded", # concurrent connections (threaded) + "--verbose", + "--verbose", + "--verbose", # triple verbosity + "--proxy", + "10.0.0.2:22", # proxy address + "--config", + "ssh_config", # ssh configuration file + "--compress", # request compression + "--noagent", # disable SSH agent key lookup + ] ) self._test_parser(parser) def _test_parser(self, parser): - self.assertEqual(parser['ssh_address'], '10.10.10.10') - self.assertEqual(parser['ssh_username'], getpass.getuser()) - self.assertEqual(parser['ssh_port'], 22) - self.assertEqual(parser['ssh_password'], SSH_PASSWORD) - self.assertListEqual(parser['remote_bind_addresses'], - [('10.0.0.1', 8080), ('10.0.0.2', 8080)]) - self.assertListEqual(parser['local_bind_addresses'], - [('', 8081), ('', 8082)]) - self.assertEqual(parser['ssh_host_key'], str(SSH_DSS)) - self.assertEqual(parser['ssh_private_key'], __file__) - self.assertEqual(parser['ssh_private_key_password'], SSH_PASSWORD) - self.assertTrue(parser['threaded']) - self.assertEqual(parser['verbose'], 3) - self.assertEqual(parser['ssh_proxy'], ('10.0.0.2', 22)) - self.assertEqual(parser['ssh_config_file'], 'ssh_config') - self.assertTrue(parser['compression']) - self.assertFalse(parser['allow_agent']) + self.assertEqual(parser["ssh_address"], "10.10.10.10") + self.assertEqual(parser["ssh_username"], getpass.getuser()) + self.assertEqual(parser["ssh_port"], 22) + self.assertEqual(parser["ssh_password"], SSH_PASSWORD) + self.assertListEqual(parser["remote_bind_addresses"], [("10.0.0.1", 8080), ("10.0.0.2", 8080)]) + self.assertListEqual(parser["local_bind_addresses"], [("", 8081), ("", 8082)]) + self.assertEqual(parser["ssh_host_key"], str(self.ssh_keys.fingerprint)) + self.assertEqual(parser["ssh_private_key"], __file__) + self.assertEqual(parser["ssh_private_key_password"], SSH_PASSWORD) + self.assertTrue(parser["threaded"]) + self.assertEqual(parser["verbose"], 3) + self.assertEqual(parser["ssh_proxy"], ("10.0.0.2", 22)) + self.assertEqual(parser["ssh_config_file"], "ssh_config") + self.assertTrue(parser["compression"]) + self.assertFalse(parser["allow_agent"]) def test_bindlist(self): """ Test that _bindlist enforces IP:PORT format for local and remote binds """ - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1:8080'), - ('10.0.0.1', 8080)) + self.assertTupleEqual(sshtunnel._bindlist("10.0.0.1:8080"), ("10.0.0.1", 8080)) # Missing port in tuple is filled with port 22 - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1:'), - ('10.0.0.1', 22)) - self.assertTupleEqual(sshtunnel._bindlist('10.0.0.1'), - ('10.0.0.1', 22)) + self.assertTupleEqual(sshtunnel._bindlist("10.0.0.1:"), ("10.0.0.1", 22)) + self.assertTupleEqual(sshtunnel._bindlist("10.0.0.1"), ("10.0.0.1", 22)) with self.assertRaises(argparse.ArgumentTypeError): - sshtunnel._bindlist('10022:10.0.0.1:22') + sshtunnel._bindlist("10022:10.0.0.1:22") with self.assertRaises(argparse.ArgumentTypeError): - sshtunnel._bindlist(':') + sshtunnel._bindlist(":") def test_raise_fwd_ext(self): - """ Test that we can silence the exceptions on sshtunnel creation """ + """Test that we can silence the exceptions on sshtunnel creation""" server = open_tunnel( - '10.10.10.10', + "10.10.10.10", ssh_username=SSH_USERNAME, ssh_password=SSH_PASSWORD, - remote_bind_address=('10.0.0.1', 8080), + remote_bind_address=("10.0.0.1", 8080), mute_exceptions=True, ) # This should not raise an exception - server._raise(sshtunnel.BaseSSHTunnelForwarderError, 'test') + server._raise(sshtunnel.BaseSSHTunnelForwarderError, "test") server._raise_fwd_exc = True # now exceptions are not silenced with self.assertRaises(sshtunnel.BaseSSHTunnelForwarderError): - server._raise(sshtunnel.BaseSSHTunnelForwarderError, 'test') + server._raise(sshtunnel.BaseSSHTunnelForwarderError, "test") def test_show_running_version(self): - """ Test that _cli_main() function quits when Enter is pressed """ + """Test that _cli_main() function quits when Enter is pressed""" with capture_stdout_stderr() as (out, err): with self.assertRaises(SystemExit): - sshtunnel._cli_main(args=['-V']) - if sys.version_info < (3, 4): - version = err.getvalue().split()[-1] - else: - version = out.getvalue().split()[-1] - self.assertEqual(version, - sshtunnel.__version__) + sshtunnel._cli_main(args=["-V"]) + version = out.getvalue().split()[-1] + self.assertEqual(version, sshtunnel.__version__) def test_remove_none_values(self): - """ Test removing keys from a dict where values are None """ - test_dict = {'key1': 1, 'key2': None, 'key3': 3, 'key4': 0} + """Test removing keys from a dict where values are None""" + test_dict = {"key1": 1, "key2": None, "key3": 3, "key4": 0} sshtunnel._remove_none_values(test_dict) - self.assertDictEqual(test_dict, - {'key1': 1, 'key3': 3, 'key4': 0}) + self.assertDictEqual(test_dict, {"key1": 1, "key3": 3, "key4": 0}) def test_read_ssh_config(self): - """ Test that we can gather host information from a config file """ - (ssh_hostname, - ssh_username, - ssh_private_key, - ssh_port, - ssh_proxy, - compression) = sshtunnel.SSHTunnelForwarder._read_ssh_config( - 'test', - get_test_data_path(TEST_CONFIG_FILE), + """Test that we can gather host information from a config file""" + (ssh_hostname, ssh_username, ssh_private_key, ssh_port, ssh_proxy, compression) = ( + sshtunnel.SSHTunnelForwarder._read_ssh_config( + "test", + self.ssh_keys.config_path, + ) ) - self.assertEqual(ssh_hostname, 'test') - self.assertEqual(ssh_username, 'test') - self.assertEqual(PKEY_FILE, ssh_private_key) + self.assertEqual(ssh_hostname, "test") + self.assertEqual(ssh_username, "test") + self.assertEqual(self.ssh_keys.plain_key_path, ssh_private_key) self.assertEqual(ssh_port, 22) # fallback value - self.assertListEqual(ssh_proxy.cmd[-2:], ['test:22', 'sshproxy']) + self.assertListEqual(ssh_proxy.cmd[-2:], ["test:22", "sshproxy"]) self.assertTrue(compression) # passed parameters are not overriden by config - (ssh_hostname, - ssh_username, - ssh_private_key, - ssh_port, - ssh_proxy, - compression) = sshtunnel.SSHTunnelForwarder._read_ssh_config( - 'other', - get_test_data_path(TEST_CONFIG_FILE), - compression=False + (ssh_hostname, ssh_username, ssh_private_key, ssh_port, ssh_proxy, compression) = ( + sshtunnel.SSHTunnelForwarder._read_ssh_config("other", self.ssh_keys.config_path, compression=False) ) - self.assertEqual(ssh_hostname, '10.0.0.1') + self.assertEqual(ssh_hostname, "10.0.0.1") self.assertEqual(ssh_port, 222) self.assertFalse(compression) def test_str(self): server = open_tunnel( - 'test', - ssh_private_key=get_test_data_path(PKEY_FILE), - remote_bind_address=('10.0.0.1', 8080), + "test", + ssh_pkey=self.ssh_keys.plain_key_path, + remote_bind_address=("10.0.0.1", 8080), ) _str = str(server).split(linesep) self.assertEqual(repr(server), str(server)) - self.assertIn('ssh gateway: test:22', _str) - self.assertIn('proxy: no', _str) - self.assertIn('username: {0}'.format(getpass.getuser()), _str) - self.assertIn('status: not started', _str) + self.assertIn("ssh gateway: test:22", _str) + self.assertIn("proxy: no", _str) + self.assertIn("username: {0}".format(getpass.getuser()), _str) + self.assertIn("status: not started", _str) def test_process_deprecations(self): - """ Test processing deprecated API attributes """ - kwargs = {'ssh_host': '10.0.0.1', - 'ssh_address': '10.0.0.1', - 'ssh_private_key': 'testrsa.key', - 'raise_exception_if_any_forwarder_have_a_problem': True} + """Test processing deprecated API attributes""" + import pytest + + kwargs = { + "ssh_host": "10.0.0.1", + "ssh_address": "10.0.0.1", + "ssh_private_key": "testrsa.key", + "raise_exception_if_any_forwarder_have_a_problem": True, + } for item in kwargs: - self.assertEqual(kwargs[item], - sshtunnel.SSHTunnelForwarder._process_deprecated( - None, - item, - kwargs.copy() - )) + with pytest.warns(DeprecationWarning, match="is DEPRECATED"): + self.assertEqual( + kwargs[item], sshtunnel.SSHTunnelForwarder._process_deprecated(None, item, kwargs.copy()) + ) # use both deprecated and not None new attribute should raise exception for item in kwargs: - with self.assertRaises(ValueError): - sshtunnel.SSHTunnelForwarder._process_deprecated('some value', - item, - kwargs.copy()) + with pytest.warns(DeprecationWarning, match="is DEPRECATED"): + with self.assertRaises(ValueError): + sshtunnel.SSHTunnelForwarder._process_deprecated("some value", item, kwargs.copy()) # deprecated attribute not in deprecation list should raise exception with self.assertRaises(ValueError): - sshtunnel.SSHTunnelForwarder._process_deprecated('some value', - 'item', - kwargs.copy()) + sshtunnel.SSHTunnelForwarder._process_deprecated("some value", "item", kwargs.copy()) def test_check_address(self): - """ Test that an exception is raised with incorrect bind addresses """ - address_list = [('10.0.0.1', 10000), - ('10.0.0.1', 10001)] - if os.name == 'posix': # UNIX sockets supported by the platform - address_list.append('/tmp/unix-socket') + """Test that an exception is raised with incorrect bind addresses""" + address_list = [("10.0.0.1", 10000), ("10.0.0.1", 10001)] + if os.name == "posix": # UNIX sockets supported by the platform + address_list.append("/tmp/unix-socket") # UNIX sockets not supported on remote addresses with self.assertRaises(AssertionError): sshtunnel.check_addresses(address_list, is_remote=True) self.assertIsNone(sshtunnel.check_addresses(address_list)) with self.assertRaises(ValueError): - sshtunnel.check_address('this is not valid') + sshtunnel.check_address("this is not valid") with self.assertRaises(ValueError): sshtunnel.check_address(-1) # that's not valid either diff --git a/tests/testconfig b/tests/testconfig deleted file mode 100644 index 9949917a..00000000 --- a/tests/testconfig +++ /dev/null @@ -1,9 +0,0 @@ -Host * - User test - Compression yes - IdentityFile testrsa.key -Host test - ProxyCommand ssh -q -W %h:%p sshproxy -Host other - Port 222 - Hostname 10.0.0.1 diff --git a/tests/testrsa.key b/tests/testrsa.key deleted file mode 100644 index f50e9c53..00000000 --- a/tests/testrsa.key +++ /dev/null @@ -1,15 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIICWgIBAAKBgQDTj1bqB4WmayWNPB+8jVSYpZYk80Ujvj680pOTh2bORBjbIAyz -oWGW+GUjzKxTiiPvVmxFgx5wdsFvF03v34lEVVhMpouqPAYQ15N37K/ir5XY+9m/ -d8ufMCkjeXsQkKqFbAlQcnWMCRnOoPHS3I4vi6hmnDDeeYTSRvfLbW0fhwIBIwKB -gBIiOqZYaoqbeD9OS9z2K9KR2atlTxGxOJPXiP4ESqP3NVScWNwyZ3NXHpyrJLa0 -EbVtzsQhLn6rF+TzXnOlcipFvjsem3iYzCpuChfGQ6SovTcOjHV9z+hnpXvQ/fon -soVRZY65wKnF7IAoUwTmJS9opqgrN6kRgCd3DASAMd1bAkEA96SBVWFt/fJBNJ9H -tYnBKZGw0VeHOYmVYbvMSstssn8un+pQpUm9vlG/bp7Oxd/m+b9KWEh2xPfv6zqU -avNwHwJBANqzGZa/EpzF4J8pGti7oIAPUIDGMtfIcmqNXVMckrmzQ2vTfqtkEZsA -4rE1IERRyiJQx6EJsz21wJmGV9WJQ5kCQQDwkS0uXqVdFzgHO6S++tjmjYcxwr3g -H0CoFYSgbddOT6miqRskOQF3DZVkJT3kyuBgU2zKygz52ukQZMqxCb1fAkASvuTv -qfpH87Qq5kQhNKdbbwbmd2NxlNabazPijWuphGTdW0VfJdWfklyS2Kr+iqrs/5wV -HhathJt636Eg7oIjAkA8ht3MQ+XSl9yIJIS8gVpbPxSw5OMfw0PjVE7tBdQruiSc -nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7 ------END RSA PRIVATE KEY----- diff --git a/tests/testrsa_encrypted.key b/tests/testrsa_encrypted.key deleted file mode 100644 index d789a8fb..00000000 --- a/tests/testrsa_encrypted.key +++ /dev/null @@ -1,18 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -Proc-Type: 4,ENCRYPTED -DEK-Info: AES-128-CBC,6FBF59B34F7A4C1BC7566EAB42D21A9B - -wk9i86OJEt1ayS1Vl08bED2n7aPK10DX/wtIjEixDeUWTPofy9HiEeWVsDxKacTp -hLR5bU5p0rzSQ/BXwADAzAaFvqkzOJ2eAz8zY9PC1hE52vniW6Z5jtTvdSyysaEP -g26Ut88EfASajqPmZJtVoViL+epv4X+sbhl8Ssh/3jLZSH+Ay3Sz2tgPXsPMbgOY -/GH73O1AolFMbm9EwxRP1RlzCnHrM2+6cVOknw5t12biSpOtVZyFf+rKPf14+r6T -/+TyBSGY3fh0LH5w8ro+s6VIuEdhJSY+CXbQ6G43vv6uTMIIZ0cQoFu2XZf7TgHm -fqfnC9Ttzy6bvpuA+WjnYkEibxW4T7TJsLjpIiesFaWn6NbhVyLrv2j9Zs+80VkS -2ue9zBGVtJOXQaRkafi7r/e3eDp8twZrfujWg5cA6RU2qF3/IzC4m5P66aRd2nwT -njgY1mrNSn+6ZLnnIV4vJ6I8RB3kctIA06a9pOWMtrKrnayLKSfpntIoYczjgHsN -rDTFgHg84u+1GWRNYaBLWaEbDPeewtc2Zi7pZQz8xGpK97NYvaok171bXg8nGfsy -Qj67/AcRSNH9l5NX1jxlj5RF7UILaS1xfNNU85w/L2vlt5zIGcTvHf54azPQjGNO -RE5d5ePea21DgX+jkxucvA9jmhiXKnvBBUg8BfQuaQQ/f9Voktk+ZSfHYabf89y3 -D+sWsl708JyuQr6hwDEb7qwv3A/cb867WFrXkptj8OBfgIAyQilTaLDjj3XuNHMC -jbr6rqbn55NP8TVdz9O1MfoeQsJxDYcCa7l3n2i6gnU= ------END RSA PRIVATE KEY----- diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 3baf2bb6..00000000 --- a/tox.ini +++ /dev/null @@ -1,40 +0,0 @@ -[tox] -envlist = syntax, py{27,34,35,36,37,38}, docs - -[testenv] -deps = - paramiko - -r{toxinidir}/tests/requirements.txt -commands = - py.test tests \ - --showlocals \ - --cov sshtunnel \ - --cov-report=term \ - --cov-report=html \ - --durations=10 \ - -n4 -W ignore::DeprecationWarning - -[testenv:docs] -changedir = docs -deps = - paramiko - -r{toxinidir}/docs/requirements.txt -commands= - sphinx-build -WavE -b html -d {envtmpdir}/_build/doctrees . {envtmpdir}/_build/html - -[testenv:syntax] -basepython = python -skip_install = True -deps = - -r{toxinidir}/tests/requirements-syntax.txt -commands = - check-manifest --ignore "tox.ini,tests*,*.yml" - python setup.py sdist - twine check dist/* - flake8 --ignore=W504 - bashtest README.rst - -[flake8] -exclude = .tox,*.egg,build,data,docs -select = E,W,F -max-complexity = 10