Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/shade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
"ShadeError",
"SyncHTTPClient",
"api_base",
"environment",
"max_retries",
"timeout",
]


class _ShadeModule(ModuleType):
"""Module subclass that exposes config-backed attributes on the shade package."""

Expand Down Expand Up @@ -72,5 +72,15 @@ def max_retries(self, value: int) -> None:
from . import config as _config
_config.max_retries = value

@property
def environment(self) -> Environment:
from . import config as _config
return _config.environment

@environment.setter
def environment(self, value: str | Environment) -> None:
from . import config as _config
_config.environment = _config.parse_environment(value)


sys.modules[__name__].__class__ = _ShadeModule
33 changes: 27 additions & 6 deletions src/shade/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,42 @@ def validate_client_settings(timeout: float, max_retries: int) -> None:


class Environment(str, Enum):
MAINNET = "mainnet"
TESTNET = "testnet"
SANDBOX = "sandbox"
PRODUCTION = "production"

@property
def base_url(self) -> str:
_urls: dict[str, str] = {
"mainnet": "https://api.shadeprotocol.io/v1",
"testnet": "https://testnet.api.shadeprotocol.io/v1",
"sandbox": "https://testnet.api.shadeprotocol.io/v1",
"production": "https://api.shadeprotocol.io/v1",
}
return _urls[self.value]

@property
def network_passphrase(self) -> str:
_passphrases: dict[str, str] = {
"mainnet": Network.PUBLIC_NETWORK_PASSPHRASE,
"testnet": Network.TESTNET_NETWORK_PASSPHRASE,
"sandbox": Network.TESTNET_NETWORK_PASSPHRASE,
"production": Network.PUBLIC_NETWORK_PASSPHRASE,
}
return _passphrases[self.value]

@property
def horizon_url(self) -> str:
_horizons: dict[str, str] = {
"sandbox": "https://horizon-testnet.stellar.org",
"production": "https://horizon.stellar.org",
}
return _horizons[self.value]


def parse_environment(value: str | Environment) -> Environment:
if isinstance(value, Environment):
return value
if isinstance(value, str):
try:
return Environment(value.lower())
except ValueError:
pass
raise ValueError("Invalid environment. Valid options are: 'sandbox', 'production'")

environment: Environment = Environment.SANDBOX
15 changes: 9 additions & 6 deletions src/shade/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class Gateway:
----------
api_key : str
Your Shade API key.
environment : Environment
environment : str | Environment, optional
Controls the Stellar network passphrase and the default API URL.
Defaults to ``Environment.MAINNET``.
Defaults to the module-level ``shade.environment`` (``Environment.SANDBOX``).
api_base : str, optional
Override the API host for this client (useful for local dev or staging).
Takes precedence over the module-level ``shade.api_base`` and the
Expand All @@ -37,7 +37,7 @@ class Gateway:
def __init__(
self,
api_key: str = "",
environment: Environment = Environment.MAINNET,
environment: Optional[Environment | str] = None,
api_base: Optional[str] = None,
base_url: str = "",
max_retries: Optional[int] = None,
Expand All @@ -46,7 +46,11 @@ def __init__(
if not api_key:
raise ValueError("api_key must be a non-empty string")
self.api_key = api_key
self.environment = environment

if environment is not None:
self.environment = _config.parse_environment(environment)
else:
self.environment = _config.environment

resolved_max_retries = (
_config.max_retries if max_retries is None else max_retries
Expand All @@ -56,9 +60,8 @@ def __init__(

# Resolution order: explicit api_base > module-level shade.api_base
# > legacy base_url > environment URL
resolved = api_base or _config.api_base or base_url or environment.base_url
resolved = api_base or _config.api_base or base_url or self.environment.base_url
self._base_url = resolved.rstrip("/")

self._http = SyncHTTPClient(
base_url=self._base_url,
api_key=api_key,
Expand Down
80 changes: 60 additions & 20 deletions tests/test_api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,45 @@ def test_reset_to_none_restores_environment_url(self):
shade.api_base = "https://staging.shadeprotocol.io"
shade.api_base = None
gw = Gateway(api_key="test-key")
assert gw._base_url == Environment.MAINNET.base_url
assert gw._base_url == Environment.SANDBOX.base_url # default is SANDBOX now


# ---------------------------------------------------------------------------
# Module-level shade.environment
# ---------------------------------------------------------------------------

class TestModuleLevelEnvironment:
@pytest.fixture(autouse=True)
def _reset_environment(self):
original = shade.environment
yield
shade.environment = original

def test_defaults_to_sandbox(self):
assert shade.environment == Environment.SANDBOX
assert _config.environment == Environment.SANDBOX

def test_assignment_updates_config(self):
shade.environment = "production"
assert shade.environment == Environment.PRODUCTION
assert _config.environment == Environment.PRODUCTION

def test_invalid_assignment_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid environment. Valid options are: 'sandbox', 'production'"):
shade.environment = "invalid"

def test_used_by_gateway_by_default(self):
shade.environment = "production"
gw = Gateway(api_key="test-key")
assert gw.environment == Environment.PRODUCTION
assert gw._base_url == Environment.PRODUCTION.base_url

def test_per_client_override(self):
shade.environment = "production"
gw = Gateway(api_key="test-key", environment="sandbox")
assert gw.environment == Environment.SANDBOX
assert gw._base_url == Environment.SANDBOX.base_url

# ---------------------------------------------------------------------------
# Per-client api_base
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -87,21 +123,21 @@ def test_http_client_uses_resolved_base_url(self):
# ---------------------------------------------------------------------------

class TestEnvironmentPassphrase:
def test_mainnet_passphrase_unchanged_when_api_base_set(self):
def test_production_passphrase_unchanged_when_api_base_set(self):
from stellar_sdk import Network
gw = Gateway(
api_key="test-key",
api_base="http://localhost:8000",
environment=Environment.MAINNET,
environment=Environment.PRODUCTION,
)
assert gw.environment.network_passphrase == Network.PUBLIC_NETWORK_PASSPHRASE

def test_testnet_passphrase_unchanged_when_api_base_set(self):
def test_sandbox_passphrase_unchanged_when_api_base_set(self):
from stellar_sdk import Network
gw = Gateway(
api_key="test-key",
api_base="http://localhost:8000",
environment=Environment.TESTNET,
environment=Environment.SANDBOX,
)
assert gw.environment.network_passphrase == Network.TESTNET_NETWORK_PASSPHRASE

Expand All @@ -110,7 +146,7 @@ def test_api_base_overrides_url_not_passphrase(self):
gw = Gateway(
api_key="test-key",
api_base="http://localhost:8000",
environment=Environment.MAINNET,
environment=Environment.PRODUCTION,
)
assert gw._base_url == "http://localhost:8000"
assert gw.environment.network_passphrase == Network.PUBLIC_NETWORK_PASSPHRASE
Expand All @@ -122,42 +158,46 @@ def test_api_base_overrides_url_not_passphrase(self):

class TestUrlResolutionPrecedence:
def test_environment_url_is_default(self):
gw = Gateway(api_key="test-key", environment=Environment.MAINNET)
assert gw._base_url == Environment.MAINNET.base_url
gw = Gateway(api_key="test-key", environment=Environment.PRODUCTION)
assert gw._base_url == Environment.PRODUCTION.base_url

def test_module_level_beats_environment(self):
shade.api_base = "https://staging.shadeprotocol.io"
gw = Gateway(api_key="test-key", environment=Environment.MAINNET)
gw = Gateway(api_key="test-key", environment=Environment.PRODUCTION)
assert gw._base_url == "https://staging.shadeprotocol.io"

def test_per_client_beats_module_level(self):
shade.api_base = "https://staging.shadeprotocol.io"
gw = Gateway(api_key="test-key", api_base="http://localhost:8000")
assert gw._base_url == "http://localhost:8000"

def test_testnet_environment_url_used_by_default(self):
gw = Gateway(api_key="test-key", environment=Environment.TESTNET)
assert gw._base_url == Environment.TESTNET.base_url
def test_sandbox_environment_url_used_by_default(self):
gw = Gateway(api_key="test-key", environment=Environment.SANDBOX)
assert gw._base_url == Environment.SANDBOX.base_url


# ---------------------------------------------------------------------------
# Environment enum
# ---------------------------------------------------------------------------

class TestEnvironment:
def test_mainnet_base_url(self):
assert Environment.MAINNET.base_url == "https://api.shadeprotocol.io/v1"
def test_production_base_url(self):
assert Environment.PRODUCTION.base_url == "https://api.shadeprotocol.io/v1"

def test_testnet_base_url(self):
assert Environment.TESTNET.base_url == "https://testnet.api.shadeprotocol.io/v1"
def test_sandbox_base_url(self):
assert Environment.SANDBOX.base_url == "https://testnet.api.shadeprotocol.io/v1"

def test_mainnet_network_passphrase(self):
def test_production_network_passphrase(self):
from stellar_sdk import Network
assert Environment.MAINNET.network_passphrase == Network.PUBLIC_NETWORK_PASSPHRASE
assert Environment.PRODUCTION.network_passphrase == Network.PUBLIC_NETWORK_PASSPHRASE

def test_testnet_network_passphrase(self):
def test_sandbox_network_passphrase(self):
from stellar_sdk import Network
assert Environment.TESTNET.network_passphrase == Network.TESTNET_NETWORK_PASSPHRASE
assert Environment.SANDBOX.network_passphrase == Network.TESTNET_NETWORK_PASSPHRASE

def test_horizon_urls(self):
assert Environment.PRODUCTION.horizon_url == "https://horizon.stellar.org"
assert Environment.SANDBOX.horizon_url == "https://horizon-testnet.stellar.org"


# ---------------------------------------------------------------------------
Expand Down
Loading