From 91fc46c0f9745674e4ab61cb05ba7aecf2c2461d Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 7 Apr 2026 11:25:09 -0500 Subject: [PATCH 1/2] refactor clients --- README.md | 2 +- examples/decimals.py | 2 +- examples/deposit.py | 2 +- examples/rest_usage.py | 2 +- examples/settlement.py | 2 +- examples/two_users_trade.py | 4 +- examples/vaults.py | 2 +- examples/websocket_usage.py | 2 +- pyproject.toml | 3 + tests/client/test_auth.py | 242 +++++++++++++++ tests/client/test_orderbook_control_ws.py | 5 +- tests/client/test_orders_404.py | 2 +- tests/client/test_tls.py | 21 +- tests/integration/test_clearing_engine.py | 2 +- tplus/client/auth.py | 140 +++++++++ tplus/client/base.py | 348 ++++++++++------------ tplus/client/clearingengine/__init__.py | 23 +- tplus/client/oms/oms_admin.py | 4 +- tplus/client/orderbook.py | 145 +++++---- tplus/evm/managers/chaindata.py | 4 +- tplus/evm/managers/settle.py | 4 +- tplus/exceptions.py | 17 ++ tplus/types.py | 4 + 23 files changed, 694 insertions(+), 288 deletions(-) create mode 100644 tests/client/test_auth.py create mode 100644 tplus/client/auth.py create mode 100644 tplus/exceptions.py create mode 100644 tplus/types.py diff --git a/README.md b/README.md index d020b01..cc146f0 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ user = User() async def run_client(): # Use async context manager for automatic cleanup - async with OrderBookClient(user=user, base_url=API_BASE_URL) as client: + async with OrderBookClient(API_BASE_URL, default_user=user) as client: print("Client initialized.") # ... use client methods ... diff --git a/examples/decimals.py b/examples/decimals.py index 8ed9ce3..6b676cf 100644 --- a/examples/decimals.py +++ b/examples/decimals.py @@ -10,7 +10,7 @@ async def main(): - client = ClearingEngineClient(User(), CLEARING_ENGINE_HOST) + client = ClearingEngineClient(CLEARING_ENGINE_HOST, default_user=User()) assets: list[AssetIdentifier | str] = [ AssetIdentifier("0xf3c3351d6bd0098eeb33ca8f830faf2a141ea2e1@421614") ] diff --git a/examples/deposit.py b/examples/deposit.py index b56607d..6e5bd6b 100644 --- a/examples/deposit.py +++ b/examples/deposit.py @@ -51,7 +51,7 @@ async def main(): ) # Connect to the t+ clearing engine. - client = ClearingEngineClient(tplus_user, CLEARING_ENGINE_HOST) + client = ClearingEngineClient(CLEARING_ENGINE_HOST, default_user=tplus_user) deposit_to_chain(blockchain_user, tplus_user) await deposit_to_ce(tplus_user, client) diff --git a/examples/rest_usage.py b/examples/rest_usage.py index 1e4cf75..a9cc893 100644 --- a/examples/rest_usage.py +++ b/examples/rest_usage.py @@ -32,7 +32,7 @@ async def main(): # Initialize client with the API base URL # Using async context manager ensures the client connection is closed properly try: - async with OrderBookClient(user, base_url=API_BASE_URL) as client: + async with OrderBookClient(API_BASE_URL, default_user=user) as client: logger.info("Client initialized.") # --- Simple GET Test First --- diff --git a/examples/settlement.py b/examples/settlement.py index fa2c7dd..a998861 100644 --- a/examples/settlement.py +++ b/examples/settlement.py @@ -27,7 +27,7 @@ async def init_settlement(client, tplus_user): async def main(): tplus_user = load_user(USERNAME) - client = ClearingEngineClient(tplus_user, CLEARING_ENGINE_HOST) + client = ClearingEngineClient(CLEARING_ENGINE_HOST, default_user=tplus_user) await init_settlement(client, tplus_user) diff --git a/examples/two_users_trade.py b/examples/two_users_trade.py index 0b75454..37ce320 100644 --- a/examples/two_users_trade.py +++ b/examples/two_users_trade.py @@ -33,8 +33,8 @@ async def main() -> None: logger.info("Connecting to OMS at %s", API_BASE_URL) async with ( - OrderBookClient(user_a, base_url=API_BASE_URL) as client_a, - OrderBookClient(user_b, base_url=API_BASE_URL) as client_b, + OrderBookClient(API_BASE_URL, default_user=user_a) as client_a, + OrderBookClient(API_BASE_URL, default_user=user_b) as client_b, ): # ------------------------------------------------------------------- # Ensure the market exists (idempotent – returns 409 if already there) diff --git a/examples/vaults.py b/examples/vaults.py index cb3a36d..7177f83 100644 --- a/examples/vaults.py +++ b/examples/vaults.py @@ -10,7 +10,7 @@ async def main(): tplus_user = load_user(USERNAME) - client = ClearingEngineClient(tplus_user, CLEARING_ENGINE_HOST) + client = ClearingEngineClient(CLEARING_ENGINE_HOST, default_user=tplus_user) vault_addresses = await client.vaults.get() pprint(vault_addresses) diff --git a/examples/websocket_usage.py b/examples/websocket_usage.py index a3cbba1..7c583b3 100644 --- a/examples/websocket_usage.py +++ b/examples/websocket_usage.py @@ -100,7 +100,7 @@ async def main(): # Removed signal handling setup - Not supported on Windows default loop try: - async with OrderBookClient(user, base_url=API_BASE_URL) as client: + async with OrderBookClient(API_BASE_URL, default_user=user) as client: logger.info("Client initialized.") # Create tasks for the stream listeners diff --git a/pyproject.toml b/pyproject.toml index 51658f9..313a2ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,9 @@ ignore = [ "S501", # ADD IT LATER ] +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["S106"] + [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py new file mode 100644 index 0000000..1d8a6a0 --- /dev/null +++ b/tests/client/test_auth.py @@ -0,0 +1,242 @@ +import asyncio +import time + +import httpx +import pytest + +from tplus.client.auth import Auth, AuthenticatedClient +from tplus.client.base import BaseClient, ClientSettings +from tplus.exceptions import MissingClientUserError +from tplus.model.types import UserPublicKey + + +class TestAuth: + def test_new_auth_is_expired(self): + auth = Auth() + assert auth.is_expired() is True + + def test_auth_with_token_but_zero_expiry_is_expired(self): + auth = Auth(token="some-token") + assert auth.is_expired() is True + + def test_auth_with_valid_token_not_expired(self): + auth = Auth(token="some-token") + # Set expiry far in the future. + auth.expiry_ns = time.time_ns() + (120 * 1_000_000_000) + assert auth.is_expired() is False + + def test_auth_expired_within_safety_margin(self): + auth = Auth(token="some-token") + # Set expiry just 30s from now (within 60s safety margin). + auth.expiry_ns = time.time_ns() + (30 * 1_000_000_000) + assert auth.is_expired() is True + + def test_auth_expired_exactly_at_margin(self): + auth = Auth(token="some-token") + # Set expiry exactly at the margin boundary. + auth.expiry_ns = time.time_ns() + Auth.SAFETY_MARGIN_NS + assert auth.is_expired() is True + + def test_auth_has_lock(self): + auth = Auth() + assert isinstance(auth.lock, asyncio.Lock) + + +class TestBaseClient: + def _make_client(self, **settings_kwargs) -> BaseClient: + return BaseClient(ClientSettings(**settings_kwargs)) + + def test_constructor_with_settings(self): + client = self._make_client(base_url="http://localhost:9999") + assert isinstance(client._client, httpx.AsyncClient) + assert client._settings.base_url == "http://localhost:9999" + + def test_constructor_with_url_string(self): + client = BaseClient("http://localhost:9999") + assert client._settings.base_url == "http://localhost:9999" + + def test_from_client_shares_internals(self): + parent = self._make_client() + child = BaseClient.from_client(parent) + assert child._client is parent._client + assert child._settings is parent._settings + + def test_validate_user_with_no_default_raises(self): + client = self._make_client() + with pytest.raises(MissingClientUserError): + client._validate_user() + + def test_validate_user_returns_default(self): + class FakeUser: + public_key = "abc" + + client = BaseClient(ClientSettings(), default_user=FakeUser()) # type: ignore + assert client._validate_user().public_key == "abc" # type: ignore + + def test_validate_user_prefers_explicit(self): + class FakeUser: + public_key = "abc" + + class OtherUser: + public_key = "xyz" + + client = BaseClient(ClientSettings(), default_user=FakeUser()) # type: ignore + assert client._validate_user(user=OtherUser()).public_key == "xyz" # type: ignore + + def test_validate_user_public_key_from_string(self): + key = UserPublicKey("ab" * 32) + client = BaseClient(ClientSettings()) + assert client._validate_user_public_key(key) == key + + def test_validate_user_public_key_from_user(self): + from tplus.utils.user import User + + user = User() + client = BaseClient(ClientSettings()) + assert client._validate_user_public_key(user) == user.public_key + + def test_validate_user_public_key_no_default_raises(self): + client = self._make_client() + with pytest.raises(MissingClientUserError): + client._validate_user_public_key() + + def test_get_request_headers_returns_settings_headers(self): + client = self._make_client() + headers = client._get_request_headers() + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + + def test_get_request_headers_returns_copy(self): + client = self._make_client() + h1 = client._get_request_headers() + h1["X-Custom"] = "foo" + h2 = client._get_request_headers() + assert "X-Custom" not in h2 + + def test_get_websocket_url_ws(self): + client = self._make_client(base_url="http://localhost:3032") + assert client._get_websocket_url("/stream") == "ws://localhost:3032/stream" + + def test_get_websocket_url_wss(self): + client = self._make_client(base_url="https://example.com") + assert client._get_websocket_url("/stream") == "wss://example.com/stream" + + def test_get_websocket_url_no_leading_slash(self): + client = self._make_client(base_url="http://localhost:3032") + assert client._get_websocket_url("stream") == "ws://localhost:3032/stream" + + def test_handle_response_204(self): + req = httpx.Request("GET", "http://example.com/test") + resp = httpx.Response(204, request=req) + client = self._make_client() + assert client._handle_response(resp) == {} + + def test_handle_response_empty_body(self): + req = httpx.Request("GET", "http://example.com/test") + resp = httpx.Response(200, request=req, content=b"") + client = self._make_client() + assert client._handle_response(resp) == {} + + def test_handle_response_json(self): + req = httpx.Request("GET", "http://example.com/test") + resp = httpx.Response(200, request=req, json={"key": "value"}) + client = self._make_client() + assert client._handle_response(resp) == {"key": "value"} + + def test_handle_response_json_null(self): + req = httpx.Request("GET", "http://example.com/test") + resp = httpx.Response( + 200, request=req, text="null", headers={"content-type": "application/json"} + ) + client = self._make_client() + assert client._handle_response(resp) == {} + + def test_handle_response_invalid_json(self): + req = httpx.Request("GET", "http://example.com/test") + resp = httpx.Response( + 200, request=req, text="not json", headers={"content-type": "application/json"} + ) + client = self._make_client() + with pytest.raises(ValueError, match="Invalid JSON"): + client._handle_response(resp) + + def test_handle_response_http_error(self): + req = httpx.Request("GET", "http://example.com/test") + resp = httpx.Response(500, request=req, text="server error") + client = self._make_client() + with pytest.raises(httpx.HTTPStatusError): + client._handle_response(resp) + + +class TestAuthenticatedClient: + def _make_client(self, auth=None, default_user=None) -> AuthenticatedClient: + return AuthenticatedClient(ClientSettings(), default_user=default_user, auth=auth) + + def test_default_auth_created(self): + client = self._make_client() + assert isinstance(client._auth, Auth) + assert client._auth.is_expired() is True + + def test_custom_auth_used(self): + auth = Auth(token="pre-set") + client = self._make_client(auth=auth) + assert client._auth is auth + assert client._auth.token == "pre-set" + + def test_none_auth_creates_default(self): + client = self._make_client(auth=None) + assert isinstance(client._auth, Auth) + + def test_get_request_headers_no_token(self): + client = self._make_client() + headers = client._get_request_headers() + assert "Authorization" not in headers + assert headers["Content-Type"] == "application/json" + + def test_get_auth_headers_with_token(self): + class FakeUser: + public_key = "user123" + + auth = Auth(token="my-token") + client = self._make_client(auth=auth, default_user=FakeUser()) # type: ignore + headers = client._get_auth_headers() + assert headers["Authorization"] == "Bearer my-token" + assert headers["User-Id"] == "user123" + + def test_get_auth_headers_no_token(self): + client = self._make_client() + assert client._get_auth_headers() == {} + + def test_from_client_preserves_type(self): + parent = self._make_client() + child = AuthenticatedClient.from_client(parent) + assert isinstance(child, AuthenticatedClient) + + +class TestClientSettings: + def test_defaults(self): + settings = ClientSettings() + assert settings.base_url == "http://localhost:3032" + assert settings.timeout == 10.0 + assert settings.insecure_ssl is False + assert settings.verify_requests is True + + def test_insecure_ssl(self): + settings = ClientSettings(insecure_ssl=True) + assert settings.verify_requests is False + + def test_parsed_base_url(self): + settings = ClientSettings(base_url="https://api.example.com:8080") + parsed = settings.parsed_base_url + assert parsed.scheme == "https" + assert parsed.hostname == "api.example.com" + assert parsed.port == 8080 + + def test_custom_headers(self): + settings = ClientSettings(headers={"X-Custom": "value"}) + assert settings.headers == {"X-Custom": "value"} + + def test_from_url(self): + settings = ClientSettings.from_url("http://example.com", insecure_ssl=True) + assert settings.base_url == "http://example.com" + assert settings.insecure_ssl is True diff --git a/tests/client/test_orderbook_control_ws.py b/tests/client/test_orderbook_control_ws.py index 0e9632a..3fcf609 100644 --- a/tests/client/test_orderbook_control_ws.py +++ b/tests/client/test_orderbook_control_ws.py @@ -47,7 +47,10 @@ async def _ensure_control_ws(self) -> None: class DummyUser: public_key = "USER" - client = DummyClient(user=DummyUser(), base_url="http://example.com") # type: ignore + client = DummyClient( + "http://example.com", + default_user=DummyUser(), # type: ignore + ) client._use_ws_control = True order_id = "abc" diff --git a/tests/client/test_orders_404.py b/tests/client/test_orders_404.py index 7b48664..140f966 100644 --- a/tests/client/test_orders_404.py +++ b/tests/client/test_orders_404.py @@ -16,7 +16,7 @@ async def _request(self, method, endpoint, json_data=None, params=None): class DummyUser: public_key = "USER" - client = DummyClient(user=DummyUser(), base_url="http://example.com") # type: ignore + client = DummyClient("http://example.com", default_user=DummyUser()) # type: ignore orders = await client.get_user_orders_for_book( asset_id=type("A", (), {"__str__": lambda self: "200"})() ) diff --git a/tests/client/test_tls.py b/tests/client/test_tls.py index 36133a1..ff1cffa 100644 --- a/tests/client/test_tls.py +++ b/tests/client/test_tls.py @@ -3,15 +3,11 @@ @pytest.mark.anyio async def test_http_client_verify_flag(monkeypatch): - from tplus.client.base import BaseClient + from tplus.client.base import BaseClient, ClientSettings - class DummyUser: - public_key = "USER" - - c = BaseClient(user=DummyUser(), base_url="http://localhost") # type: ignore + settings = ClientSettings(base_url="http://localhost") + c = BaseClient(settings) try: - # Insecure is False by default → default httpx AsyncClient verifies certs by default - # We cannot access private verify attribute reliably; ensure auth headers would carry token when set assert isinstance(c._client, type(c._client)) finally: await c.close() @@ -19,14 +15,11 @@ class DummyUser: @pytest.mark.anyio async def test_http_client_insecure_ssl_disables_verify(monkeypatch): - from tplus.client.base import BaseClient - - class DummyUser: - public_key = "USER" + from tplus.client.base import BaseClient, ClientSettings - c = BaseClient(user=DummyUser(), base_url="http://localhost", insecure_ssl=True) # type: ignore + settings = ClientSettings(base_url="http://localhost", insecure_ssl=True) + c = BaseClient(settings) try: - # Construction should succeed with insecure flag; httpx accepts verify=False - assert c._insecure_ssl is True + assert settings.insecure_ssl is True finally: await c.close() diff --git a/tests/integration/test_clearing_engine.py b/tests/integration/test_clearing_engine.py index de26469..d30392d 100644 --- a/tests/integration/test_clearing_engine.py +++ b/tests/integration/test_clearing_engine.py @@ -12,7 +12,7 @@ def user() -> User: @pytest.fixture(scope="module") def clearing_engine(user): - return ClearingEngineClient(user, "http://127.0.0.1:3032") + return ClearingEngineClient.from_local(user) @pytest.fixture(scope="module") diff --git a/tplus/client/auth.py b/tplus/client/auth.py new file mode 100644 index 0000000..d4e9efb --- /dev/null +++ b/tplus/client/auth.py @@ -0,0 +1,140 @@ +import asyncio +import time +from typing import TYPE_CHECKING, Any + +from tplus.client.base import BaseClient + +if TYPE_CHECKING: + from tplus.types import UserType + from tplus.utils.user import User + + +class Auth: + SAFETY_MARGIN_NS = 60 * 1_000_000_000 + + def __init__(self, token: str | None = None) -> None: + self.lock = asyncio.Lock() + self.token = token + self.expiry_ns = 0 + + def is_expired(self) -> bool: + if self.token and (time.time_ns() + self.SAFETY_MARGIN_NS) < self.expiry_ns: + return False + return True + + +class AuthenticatedClient(BaseClient): + """ + A BaseClient that adds token-based authentication. + """ + + def __init__(self, *args, auth: Auth | None = None, **kwargs): + super().__init__(*args, **kwargs) + self._auth = auth or Auth() + + async def _request( + self, + method: str, + endpoint: str, + json_data: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> dict[str, Any]: + relative_url = endpoint if endpoint.startswith("/") else f"/{endpoint}" + + if not relative_url.startswith("/nonce") and not relative_url.startswith("/auth"): + await self._ensure_auth() + + response = await self._send(method, relative_url, json_data=json_data, params=params) + + # If we receive an HTTP 401/403, the auth token may have expired. Refresh the + # credentials **once** and retry the request automatically. + if response.status_code in {401, 403} and not relative_url.startswith("/auth"): + self.logger.info( + "Received %s for %s – refreshing auth token and retrying once.", + response.status_code, + relative_url, + ) + await self._authenticate() + response = await self._send(method, relative_url, json_data=json_data, params=params) + + return self._handle_response(response) + + def _get_request_headers(self) -> dict[str, str]: + headers = super()._get_request_headers() + headers.update(self._get_auth_headers()) + return headers + + def _get_auth_headers(self, user: "UserType | None" = None) -> dict[str, str]: + if not self._auth.token: + return {} + + return { + "Authorization": f"Bearer {self._auth.token}", + "User-Id": self._validate_user_public_key(user=user), + } + + async def _ensure_auth(self, user: "User | None" = None) -> None: + if not self._auth.is_expired(): + return + + async with self._auth.lock: + if not self._auth.is_expired(): + return + + await self._authenticate(user=user) + + async def _authenticate(self, user: "User | None" = None) -> None: + user = user or self._validate_user() + nonce_endpoint = f"/nonce/{user.public_key}" + nonce_resp = await self._client.get(nonce_endpoint) # type: ignore + nonce_resp.raise_for_status() + nonce_data = nonce_resp.json() if hasattr(nonce_resp, "json") else nonce_resp + + # NOTE: nonce_value **must** be a `str` here. + nonce_value = f"{nonce_data['value']}" if isinstance(nonce_data, dict) else f"{nonce_data}" + + signature_bytes = user.sign(nonce_value) + signature_array = list(signature_bytes) + nonce_value_len = len(nonce_value) # type: ignore + + self.logger.debug(f"AUTH DEBUG: nonce={nonce_value} (len={nonce_value_len})") + self.logger.debug( + f"AUTH DEBUG: signature={signature_array[:8]}... (len={len(signature_array)})" + ) + + auth_payload = { + "user_id": user.public_key, + "nonce": nonce_value, + "signature": signature_array, + } + + token_resp = await self._client.post("/auth", json=auth_payload) # type: ignore + token_resp.raise_for_status() + token_json = token_resp.json() if hasattr(token_resp, "json") else token_resp + + token = token_json.get("token") # type: ignore + expiry_ns = int(token_json["expiry_ns"]) # type: ignore + + # Mask token if present + if isinstance(token, str): + masked = token[:4] + "…" + token[-4:] if len(token) >= 8 else "***" + else: + masked = "***" + self.logger.debug(f"AUTH DEBUG: token={masked} expires={expiry_ns}") + + self._auth.token = token_json["token"] # type: ignore + self._auth.expiry_ns = expiry_ns + + async def _ws_auth_headers(self) -> dict[str, str]: + await self._ensure_auth() + return self._get_auth_headers() + + async def _open_ws( + self, + path: str, + ws_kwargs: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, + ): + auth_headers = await self._ws_auth_headers() + extra_headers = {**auth_headers, **(extra_headers or {})} + return await super()._open_ws(path, ws_kwargs=ws_kwargs, extra_headers=extra_headers) diff --git a/tplus/client/base.py b/tplus/client/base.py index ab2bb85..5fe118d 100644 --- a/tplus/client/base.py +++ b/tplus/client/base.py @@ -2,61 +2,127 @@ import logging import ssl from collections.abc import AsyncIterator, Callable -from typing import Any +from functools import cached_property +from typing import TYPE_CHECKING, Any, Self from urllib.parse import urlparse import httpx import websockets +from pydantic import BaseModel +from tplus.exceptions import MissingClientUserError from tplus.logger import get_logger from tplus.utils.user import User +if TYPE_CHECKING: + from tplus.model.types import UserPublicKey + from tplus.types import UserType + +DEFAULT_TIMEOUT = 10.0 +DEFAULT_HEADERS = {"Content-Type": "application/json", "Accept": "application/json"} + + +class ClientSettings(BaseModel): + """ + Validated client settings. + """ + + base_url: str = "http://localhost:3032" + """ + Base URL for requests. + """ + + timeout: float = DEFAULT_TIMEOUT + """ + Requests timeout. + """ + + websocket_kwargs: dict[str, Any] = {} + """ + Additional kwargs to pass to websocket requests. + """ + + insecure_ssl: bool = False + """ + Set to to not verify SSL certificates. + """ + + headers: dict[str, Any] = DEFAULT_HEADERS + """ + HTTP headers. + """ + + @classmethod + def from_url(cls, url: str, **kwargs) -> "ClientSettings": + return cls(base_url=url, **kwargs) + + @cached_property + def parsed_base_url(self): + return urlparse(self.base_url) + + @property + def verify_requests(self) -> bool: + return not self.insecure_ssl + + +def create_httpx_client(settings: ClientSettings) -> httpx.AsyncClient: + return httpx.AsyncClient( + base_url=settings.base_url, + timeout=settings.timeout, + headers=settings.headers, + verify=settings.verify_requests, + ) + class BaseClient: """ Base client to use across T+ services. """ - DEFAULT_TIMEOUT = 10.0 - AUTH = True - def __init__( self, - user: User, - base_url: str, - timeout: float = DEFAULT_TIMEOUT, - client: httpx.AsyncClient | None = None, - websocket_kwargs: dict[str, Any] | None = None, + settings: ClientSettings | str, + default_user: User | None = None, log_level: int = logging.INFO, - insecure_ssl: bool = False, + client: httpx.AsyncClient | None = None, + **kwargs, ): - self.user = user - self.base_url = base_url.rstrip("/") - self._parsed_base_url = urlparse(self.base_url) - if not isinstance(insecure_ssl, bool): - raise TypeError("insecure_ssl must be a bool") - self._insecure_ssl: bool = insecure_ssl - self._client = client or httpx.AsyncClient( - base_url=self.base_url, - timeout=timeout, - headers={"Content-Type": "application/json", "Accept": "application/json"}, - verify=not self._insecure_ssl, - ) - self._ws_kwargs: dict[str, Any] = websocket_kwargs or {} - - import asyncio + if isinstance(settings, str): + settings = ClientSettings.from_url(settings) - self._auth_lock: asyncio.Lock = asyncio.Lock() - self._auth_token: str | None = None - self._auth_expiry_ns: int = 0 + self._settings = settings + self._default_user = default_user + self._client = client or create_httpx_client(settings) self.logger = get_logger(log_level=log_level) @classmethod - def from_client(cls, client: "BaseClient"): - """ - Easy way to clone clients without initializing multiple AsyncClients. - """ - return cls(client.user, client.base_url, client=client._client) + def from_client(cls, client: "BaseClient") -> Self: + return cls( + client._settings, + default_user=client._default_user, + client=client._client, + ) + + def _validate_user(self, user: User | None = None) -> User: + if user is not None: + return user + + elif self._default_user is None: + raise MissingClientUserError() + + return self._default_user + + def _validate_user_public_key(self, user: "UserType | None" = None) -> "UserPublicKey": + if user is not None: + if isinstance(user, User): + return user.public_key + + return user + + elif self._default_user is None: + raise MissingClientUserError() + + return self._default_user.public_key async def _get(self, endpoint: str, json_data: dict[str, Any] | None = None) -> dict[str, Any]: return await self._request("GET", endpoint, json_data=json_data) @@ -72,175 +138,76 @@ async def _request( params: dict[str, Any] | None = None, ) -> dict[str, Any]: relative_url = endpoint if endpoint.startswith("/") else f"/{endpoint}" + response = await self._send(method, relative_url, json_data=json_data, params=params) + return self._handle_response(response) - if self.AUTH and ( - not relative_url.startswith("/nonce") and not relative_url.startswith("/auth") - ): - await self._ensure_auth() + def _get_request_headers(self) -> dict[str, str]: + return dict(self._settings.headers) - try: - if json_data or params: - self.logger.debug( - f"Request to {method} {relative_url} with payload: {json_data} params: {params}" - ) - request_headers = self._get_auth_headers() - if request_headers: - merged_headers = {**self._client.headers, **request_headers} - else: - merged_headers = None + async def _send( + self, + method: str, + relative_url: str, + json_data: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> httpx.Response: + if json_data or params: + self.logger.debug( + f"Request to {method} {relative_url} with payload: {json_data} params: {params}" + ) - response = await self._client.request( # type: ignore + merged_headers = self._get_request_headers() + + try: + return await self._client.request( # type: ignore method=method, url=relative_url, json=json_data, params=params, headers=merged_headers, ) + except httpx.TimeoutException as err: + self.logger.error(f"Request timed out to {err.request.url!r}: {err}") + raise - # If we receive an HTTP 401/403, the auth token may have expired. Refresh the - # credentials **once** and retry the request automatically. This keeps the - # higher-level client APIs unaware of token lifetimes and greatly simplifies - # consumer code. - if ( - self.AUTH - and response.status_code in {401, 403} - and not relative_url.startswith("/auth") - ): - self.logger.info( - "Received %s for %s – refreshing auth token and retrying once.", - response.status_code, - relative_url, - ) + except httpx.RequestError as err: + self.logger.error( + f"An error occurred while requesting {err.request.url!r}: {type(err).__name__} - {err}" + ) + raise - # Force re-authentication and rebuild the auth headers (inside the same - # lock to avoid a thundering herd when many coroutines hit expiry at the - # same time). - await self._authenticate() - retry_headers = {**self._client.headers, **self._get_auth_headers()} - - response = await self._client.request( # type: ignore - method=method, - url=relative_url, - json=json_data, - params=params, - headers=retry_headers, - ) + def _handle_response(self, response: httpx.Response) -> dict[str, Any]: + if response.status_code == 204: + return {} - if response.status_code == 204: - return {} + raise_for_status_with_body(response) - raise_for_status_with_body(response) + if not response.content: + return {} - if not response.content: + try: + json_response = response.json() + if json_response is None: + self.logger.warning( + f"API endpoint {response.request.url!r} returned JSON null. Treating as empty dictionary." + ) return {} - try: - json_response = response.json() - if json_response is None: - self.logger.warning( - f"API endpoint {response.request.url!r} returned JSON null. Treating as empty dictionary." - ) - return {} - - return json_response - - except Exception: - raise Exception( - f"Invalid response from server - status_code={response.status_code}." - ) + return json_response - except httpx.TimeoutException as e: - self.logger.error(f"Request timed out to {e.request.url!r}: {e}") - raise - except httpx.RequestError as e: - self.logger.error( - f"An error occurred while requesting {e.request.url!r}: {type(e).__name__} - {e}" - ) - raise - except httpx.HTTPStatusError as e: - self.logger.error( - f"HTTP error {e.response.status_code} while requesting {e.request.url!r}: {e.response.text}" - ) - raise except json.JSONDecodeError as e: self.logger.error( - f"Failed to decode JSON response from {response.request.url!r}. Status: {response.status_code}. Content: {response.text[:100]}..." + f"Failed to decode JSON response from {response.request.url!r}. " + f"Status: {response.status_code}. Content: {response.text[:100]}..." ) raise ValueError(f"Invalid JSON received from API: {e}") from e - async def _ensure_auth(self) -> None: - import time - - safety_margin_ns = 60 * 1_000_000_000 - if self._auth_token and (time.time_ns() + safety_margin_ns) < self._auth_expiry_ns: - return - - async with self._auth_lock: - if self._auth_token and (time.time_ns() + safety_margin_ns) < self._auth_expiry_ns: - return - - await self._authenticate() - - def _get_auth_headers(self) -> dict[str, str]: - if not self._auth_token: - return {} - return { - "Authorization": f"Bearer {self._auth_token}", - "User-Id": self.user.public_key, - } - - async def _authenticate(self) -> None: - nonce_endpoint = f"/nonce/{self.user.public_key}" - nonce_resp = await self._client.get(nonce_endpoint) # type: ignore - nonce_resp.raise_for_status() - nonce_data = nonce_resp.json() if hasattr(nonce_resp, "json") else nonce_resp - - # NOTE: nonce_value **must** be a `str` here. - nonce_value = f"{nonce_data['value']}" if isinstance(nonce_data, dict) else f"{nonce_data}" - - signature_bytes = self.user.sign(nonce_value) - signature_array = list(signature_bytes) - nonce_value_len = len(nonce_value) # type: ignore - - self.logger.debug(f"AUTH DEBUG: nonce={nonce_value} (len={nonce_value_len})") - self.logger.debug( - f"AUTH DEBUG: signature={signature_array[:8]}... (len={len(signature_array)})" - ) - - auth_payload = { - "user_id": self.user.public_key, - "nonce": nonce_value, - "signature": signature_array, - } - - token_resp = await self._client.post("/auth", json=auth_payload) # type: ignore - token_resp.raise_for_status() - token_json = token_resp.json() if hasattr(token_resp, "json") else token_resp - - token = token_json.get("token") # type: ignore - expiry_ns = int(token_json["expiry_ns"]) # type: ignore - - # Mask token if present - if isinstance(token, str): - masked = token[:4] + "…" + token[-4:] if len(token) >= 8 else "***" - else: - masked = "***" - self.logger.debug(f"AUTH DEBUG: token={masked} expires={expiry_ns}") - - self._auth_token = token_json["token"] # type: ignore - self._auth_expiry_ns = expiry_ns - - async def _ws_auth_headers(self) -> dict[str, str]: - if self.AUTH: - await self._ensure_auth() - - return self._get_auth_headers() - def _get_websocket_url(self, path: str) -> str: from urllib.parse import urlunparse - scheme = "wss" if self._parsed_base_url.scheme == "https" else "ws" - netloc = self._parsed_base_url.netloc + parsed = self._settings.parsed_base_url + scheme = "wss" if parsed.scheme == "https" else "ws" + netloc = parsed.netloc ws_path = path if path.startswith("/") else f"/{path}" return urlunparse((scheme, netloc, ws_path, "", "", "")) @@ -248,38 +215,41 @@ async def _open_ws( self, path: str, ws_kwargs: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, ): """ Build a WebSocket connection context for the given path with proper - auth headers and TLS/handshake settings. Returns a websockets.connect - context manager which can be used with "async with". + TLS/handshake settings. Returns a websockets.connect context manager. """ ws_url = self._get_websocket_url(path) - auth_headers = await self._ws_auth_headers() - final_kwargs = dict(self._ws_kwargs) + headers = extra_headers or {} + + final_kwargs = dict(self._settings.websocket_kwargs) if ws_kwargs: final_kwargs.update(ws_kwargs) - # Merge extra headers with auth headers + # Merge extra headers with caller-provided headers if "extra_headers" in final_kwargs and final_kwargs["extra_headers"]: caller_headers = final_kwargs.pop("extra_headers") if isinstance(caller_headers, dict): - caller_headers.update(auth_headers) + caller_headers.update(headers) final_kwargs["extra_headers"] = caller_headers else: - final_kwargs["extra_headers"] = list(auth_headers.items()) + list(caller_headers) + final_kwargs["extra_headers"] = list(headers.items()) + list(caller_headers) else: - final_kwargs["extra_headers"] = auth_headers + final_kwargs["extra_headers"] = headers + + parsed = self._settings.parsed_base_url # Provide Origin to be proxy/gateway friendly - origin = f"{self._parsed_base_url.scheme}://{self._parsed_base_url.netloc}" + origin = f"{parsed.scheme}://{parsed.netloc}" if "origin" not in final_kwargs: final_kwargs["origin"] = origin # Build SSL context (secure by default). Only set ALPN/server_hostname for HTTPS. - if self._parsed_base_url.scheme == "https": + if parsed.scheme == "https": ssl_context = ssl.create_default_context() - if self._insecure_ssl: + if self._settings.insecure_ssl: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE try: @@ -287,7 +257,7 @@ async def _open_ws( except Exception: pass - server_hostname = self._parsed_base_url.hostname + server_hostname = parsed.hostname if server_hostname: return websockets.connect( ws_url, **final_kwargs, ssl=ssl_context, server_hostname=server_hostname diff --git a/tplus/client/clearingengine/__init__.py b/tplus/client/clearingengine/__init__.py index 86e4235..fc02385 100644 --- a/tplus/client/clearingengine/__init__.py +++ b/tplus/client/clearingengine/__init__.py @@ -1,7 +1,7 @@ from functools import cached_property -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from tplus.client.base import BaseClient +from tplus.client.base import ClientSettings from tplus.client.clearingengine.admin import AdminClient from tplus.client.clearingengine.assetregistry import AssetRegistryClient from tplus.client.clearingengine.base import BaseClearingEngineClient @@ -22,54 +22,55 @@ class ClearingEngineClient(BaseClearingEngineClient): """ @classmethod - def from_local(cls, user: "User", port: int = 3032): - return cls(user, base_url=f"http://127.0.0.1:{port}") + def from_local(cls, user: "User", port: int = 3032, insecure_ssl: bool = False): + settings = ClientSettings(base_url=f"http://127.0.0.1:{port}", insecure_ssl=insecure_ssl) + return cls(settings, default_user=user) @cached_property def settlements(self) -> SettlementClient: """ APIs related to settlements. """ - return SettlementClient.from_client(cast(BaseClient, self)) + return SettlementClient.from_client(self) @cached_property def assets(self) -> AssetRegistryClient: """ APIs related to registered assets. """ - return AssetRegistryClient.from_client(cast(BaseClient, self)) + return AssetRegistryClient.from_client(self) @cached_property def decimals(self) -> DecimalClient: """ APIs related to decimals. """ - return DecimalClient.from_client(cast(BaseClient, self)) + return DecimalClient.from_client(self) @cached_property def deposits(self) -> DepositClient: """ APIs related to deposits. """ - return DepositClient.from_client(cast(BaseClient, self)) + return DepositClient.from_client(self) @cached_property def withdrawals(self) -> WithdrawalClient: """ APIs related to withdrawals. """ - return WithdrawalClient.from_client(cast(BaseClient, self)) + return WithdrawalClient.from_client(self) @cached_property def vaults(self) -> VaultClient: """ APIs related to vaults. """ - return VaultClient.from_client(cast(BaseClient, self)) + return VaultClient.from_client(self) @cached_property def admin(self) -> AdminClient: """ APIs related to the admin clearing-engine. """ - return AdminClient.from_client(cast(BaseClient, self)) + return AdminClient.from_client(self) diff --git a/tplus/client/oms/oms_admin.py b/tplus/client/oms/oms_admin.py index 3665b81..28c7754 100644 --- a/tplus/client/oms/oms_admin.py +++ b/tplus/client/oms/oms_admin.py @@ -1,7 +1,7 @@ -from tplus.client.base import BaseClient +from tplus.client.auth import AuthenticatedClient -class OmsAdminClient(BaseClient): +class OmsAdminClient(AuthenticatedClient): async def set_settings( self, solvency_verifier: str, diff --git a/tplus/client/orderbook.py b/tplus/client/orderbook.py index c5c4ab6..1e29865 100644 --- a/tplus/client/orderbook.py +++ b/tplus/client/orderbook.py @@ -3,14 +3,13 @@ import base64 import contextlib import json -import logging import uuid from collections.abc import AsyncIterator, Callable from typing import TYPE_CHECKING, Any import httpx -from tplus.client.base import BaseClient +from tplus.client.auth import AuthenticatedClient from tplus.model.asset_identifier import AssetIdentifier from tplus.model.klines import KlineUpdate, parse_kline_update from tplus.model.limit_order import GTC, GTD, IOC @@ -45,6 +44,7 @@ UserSolvency, parse_user_solvency, ) +from tplus.types import UserType from tplus.utils.limit_order import ( create_limit_order_ob_request_payload, ) @@ -72,7 +72,7 @@ def compute_remaining(order: OrderResponse) -> int: return max(0, total_qty - confirmed - pending) -class OrderBookClient(BaseClient): +class OrderBookClient(AuthenticatedClient): """Client for HTTP + WebSocket interactions with the OMS. Extra keyword-arguments for the underlying ``websockets.connect`` call can @@ -82,21 +82,11 @@ class OrderBookClient(BaseClient): def __init__( self, - user: "User", - *, - base_url: str, - websocket_kwargs: dict[str, Any] | None = None, - log_level: int = logging.INFO, + *args, use_ws_control: bool = False, - insecure_ssl: bool = False, + **kwargs, ) -> None: - super().__init__( - user, - base_url=base_url, - websocket_kwargs=websocket_kwargs, - log_level=log_level, - insecure_ssl=insecure_ssl, - ) + super().__init__(*args, **kwargs) # Cache Market details per asset to avoid repeated GET /market calls self._market_cache: dict[str, Market] = {} # When True, create/replace/cancel are sent via WS /control instead of HTTP @@ -144,10 +134,12 @@ async def create_market_order( fill_or_kill: bool = False, asset_id: AssetIdentifier | None = None, target: TradeTarget | None = None, + user: "User | None" = None, ) -> OrderOperationResponse: """ Create a market order (async). Uses WS /control if enabled. """ + user = self._validate_user(user) # TODO: Fix the signature of this method so that `asset_id` is required. asset_id_unwrapped: AssetIdentifier = asset_id # type: ignore @@ -166,7 +158,7 @@ async def create_market_order( ob_request_payload = create_market_order_ob_request_payload( side=side, - signer=self.user, + signer=user, book_quantity_decimals=market.book_quantity_decimals, book_price_decimals=market.book_price_decimals, asset_identifier=asset_id_unwrapped, @@ -196,10 +188,12 @@ async def create_limit_order( time_in_force: GTC | GTD | IOC | None = None, asset_id: AssetIdentifier | None = None, target: TradeTarget | None = None, + user: "User | None" = None, ) -> OrderOperationResponse: """ Create a limit order (async). Uses WS /control if enabled. """ + user = self._validate_user(user) # TODO: Fix the signature if this method such that `asset_id` is required. asset_id_unwrapped: AssetIdentifier = asset_id # type: ignore @@ -209,7 +203,7 @@ async def create_limit_order( quantity=quantity, price=price, side=side, - signer=self.user, + signer=user, book_quantity_decimals=market.book_quantity_decimals, book_price_decimals=market.book_price_decimals, asset_identifier=asset_id_unwrapped, @@ -228,13 +222,14 @@ async def create_limit_order( return OrderOperationResponse.model_validate(resp) async def cancel_order( - self, order_id: str, asset_id: AssetIdentifier + self, order_id: str, asset_id: AssetIdentifier, user: "User | None" = None ) -> OrderOperationResponse: """ Cancel an order (async). Uses WS /control if enabled. """ + user = self._validate_user(user) signed_message = create_cancel_order_ob_request_payload( - order_id=order_id, asset_identifier=asset_id, signer=self.user + order_id=order_id, asset_identifier=asset_id, signer=user ) self.logger.debug(f"Sending Cancel Order Request: OrderID={order_id}, Asset={asset_id}") if self._use_ws_control: @@ -252,15 +247,17 @@ async def replace_order( asset_id: AssetIdentifier, new_quantity: int | None = None, new_price: int | None = None, + user: "User | None" = None, ) -> OrderOperationResponse: """ Replace an existing order with new parameters (async). Uses WS /control if enabled. """ + user = self._validate_user(user) market = await self.get_market(asset_id) signed_message = create_replace_order_ob_request_payload( original_order_id=original_order_id, asset_identifier=asset_id, - signer=self.user, + signer=user, new_price=new_price, new_quantity=new_quantity, book_price_decimals=market.book_price_decimals, @@ -347,30 +344,37 @@ async def get_klines( parsed_data = parse_kline_update(response_data) return parsed_data - async def get_user_trades(self) -> list[UserTrade]: + async def get_user_trades(self, user: UserType | None = None) -> list[UserTrade]: """ Get all trades for the authenticated user (async). """ - endpoint = f"/trades/user/{self.user.public_key}" - self.logger.debug(f"Getting Trades for user {self.user.public_key}") + public_key = self._validate_user_public_key(user) + endpoint = f"/trades/user/{public_key}" + self.logger.debug(f"Getting Trades for user {public_key}") response_data = await self._request("GET", endpoint) return self.parse_user_trades(response_data) # type: ignore - async def get_user_trades_for_asset(self, asset_id: AssetIdentifier) -> list[UserTrade]: + async def get_user_trades_for_asset( + self, asset_id: AssetIdentifier, user: UserType | None = None + ) -> list[UserTrade]: """ Get trades for a specific asset for the authenticated user (async). """ - endpoint = f"/trades/user/{self.user.public_key}/{asset_id}" - self.logger.debug(f"Getting Trades for user {self.user.public_key}, asset {asset_id}") + public_key = self._validate_user_public_key(user) + endpoint = f"/trades/user/{public_key}/{asset_id}" + self.logger.debug(f"Getting Trades for user {public_key}, asset {asset_id}") response_data = await self._request("GET", endpoint) return self.parse_user_trades(response_data) # type: ignore - async def get_user_orders(self) -> tuple[list[OrderResponse], dict[str, Any]]: + async def get_user_orders( + self, user: UserType | None = None + ) -> tuple[list[OrderResponse], dict[str, Any]]: """ Get all orders for the authenticated user (async). """ - endpoint = f"/orders/user/{self.user.public_key}" - self.logger.debug(f"Getting Orders for user {self.user.public_key}") + public_key = self._validate_user_public_key(user) + endpoint = f"/orders/user/{public_key}" + self.logger.debug(f"Getting Orders for user {public_key}") response_data = await self._request("GET", endpoint) if isinstance(response_data, dict) and "error" in response_data: @@ -394,12 +398,14 @@ async def get_user_orders_for_book( page: int | None = None, limit: int | None = None, open_only: bool | None = None, + user: UserType | None = None, ) -> list[OrderResponse]: """ Get orders for a specific asset for the authenticated user (async). Handles 404 with empty list as "no orders" gracefully. """ - endpoint = f"/orders/user/{self.user.public_key}/{asset_id}" + public_key = self._validate_user_public_key(user) + endpoint = f"/orders/user/{public_key}/{asset_id}" params_dict: dict[str, Any] | None = None if page is not None or limit is not None or open_only is not None: params_dict = { @@ -408,7 +414,7 @@ async def get_user_orders_for_book( } if open_only is not None: params_dict["open_only"] = bool(open_only) - self.logger.debug(f"Getting Orders for user {self.user.public_key}, asset {asset_id}") + self.logger.debug(f"Getting Orders for user {public_key}, asset {asset_id}") try: response_data = await self._request("GET", endpoint, params=params_dict) @@ -429,18 +435,18 @@ async def get_user_orders_for_book( content = e.response.json() if isinstance(content, list) and not content: self.logger.debug( - f"Received 404 with empty list for {endpoint} (User: {self.user.public_key}, Asset: {asset_id}). " + f"Received 404 with empty list for {endpoint} (User: {public_key}, Asset: {asset_id}). " f"This is expected if the user has no orders for this asset yet. Treating as success with no orders." ) return [] else: self.logger.warning( - f"Received 404 for {endpoint} (User: {self.user.public_key}, Asset: {asset_id}), " + f"Received 404 for {endpoint} (User: {public_key}, Asset: {asset_id}), " f"but response body was not an empty list as expected for 'no orders'. Body: {e.response.text[:200]}" ) except json.JSONDecodeError: self.logger.warning( - f"Received 404 for {endpoint} (User: {self.user.public_key}, Asset: {asset_id}), " + f"Received 404 for {endpoint} (User: {public_key}, Asset: {asset_id}), " f"but response body was not valid JSON. Body: {e.response.text[:200]}" ) raise e @@ -629,12 +635,13 @@ async def close(self) -> None: self._control_ws_task = None await super().close() - async def get_user_inventory(self) -> dict[str, Any]: + async def get_user_inventory(self, user: UserType | None = None) -> dict[str, Any]: """ Get inventory for the authenticated user (async). """ - endpoint = f"/inventory/user/{self.user.public_key}" - self.logger.debug(f"Getting Inventory for user {self.user.public_key}") + public_key = self._validate_user_public_key(user) + endpoint = f"/inventory/user/{public_key}" + self.logger.debug(f"Getting Inventory for user {public_key}") return await self._request("GET", endpoint) async def stream_orders(self) -> AsyncIterator[OrderEvent]: @@ -675,7 +682,7 @@ async def stream_klines(self, asset_id: AssetIdentifier) -> AsyncIterator[KlineU yield kline async def stream_user_trade_events( - self, user_id: str | None = None + self, user_id: str | None = None, user: UserType | None = None ) -> AsyncIterator[UserTrade]: """ Stream **all** trade events (``Pending``, ``Confirmed``, ``Rollbacked``) for a specific user. @@ -683,17 +690,18 @@ async def stream_user_trade_events( Args: user_id: Optional explicit user identifier. If not provided, the authenticated user's public key is used. + user: Optional User or public key to resolve the user identity. Yields: :class:`tplus.model.trades.UserTrade` objects with detailed order-side information. """ if user_id is None: - user_id = self.user.public_key + user_id = self._validate_user_public_key(user) path = f"/trades/user/events/{user_id}" async for trade in self._stream_ws(path, parse_single_user_trade): yield trade async def stream_user_finalized_trades( - self, user_id: str | None = None + self, user_id: str | None = None, user: UserType | None = None ) -> AsyncIterator[UserTrade]: """ Stream **finalized** (confirmed) trades for a specific user. @@ -701,11 +709,12 @@ async def stream_user_finalized_trades( Args: user_id: Optional explicit user identifier. If not provided, the authenticated user's public key is used. + user: Optional User or public key to resolve the user identity. Yields: :class:`tplus.model.trades.UserTrade` instances containing only confirmed trades. """ if user_id is None: - user_id = self.user.public_key + user_id = self._validate_user_public_key(user) path = f"/trades/user/{user_id}" async for trade in self._stream_ws(path, parse_single_user_trade): yield trade @@ -718,13 +727,14 @@ async def stream_user_trades(self, user_id: str | None = None) -> AsyncIterator[ async for trade in self.stream_user_finalized_trades(user_id=user_id): yield trade - async def get_user_solvency(self) -> UserSolvency: + async def get_user_solvency(self, user: UserType | None = None) -> UserSolvency: """ Get solvency for the authenticated user (async). """ - endpoint = f"/solvency/user/{self.user.public_key}" + public_key = self._validate_user_public_key(user) + endpoint = f"/solvency/user/{public_key}" - self.logger.debug(f"Getting Solvency for user {self.user.public_key}") + self.logger.debug(f"Getting Solvency for user {public_key}") response_data = await self._request("GET", endpoint) if not isinstance(response_data, dict): @@ -737,6 +747,7 @@ async def get_user_margin_info( self, sub_accounts: list[int] | None = None, include_positions: bool = False, + user: UserType | None = None, ) -> UserMarginInfo: """ Get detailed margin breakdown for the authenticated user (async). @@ -757,6 +768,7 @@ async def get_user_margin_info( If None or empty, returns info for all sub-accounts. include_positions: If True, includes per-position breakdown with size and notional value for each position. + user: Optional User or public key. Falls back to the default user. Returns: UserMarginInfo containing margin breakdown per sub-account. @@ -764,7 +776,8 @@ async def get_user_margin_info( Raises: Exception: If the API response is invalid. """ - endpoint = f"/margin/user/{self.user.public_key}" + public_key = self._validate_user_public_key(user) + endpoint = f"/margin/user/{public_key}" params: dict[str, Any] = {} if sub_accounts: @@ -773,7 +786,7 @@ async def get_user_margin_info( params["include_positions"] = include_positions self.logger.debug( - f"Getting Margin Info for user {self.user.public_key}, " + f"Getting Margin Info for user {public_key}, " f"sub_accounts={sub_accounts}, include_positions={include_positions}" ) response_data = await self._request("GET", endpoint, params=params if params else None) @@ -791,9 +804,16 @@ async def request_transfer_to_subaccount( transfer_asset: AssetIdentifier, transfer_amount: int, target_account_type: None = None, + user: "User | None" = None, ) -> dict[str, Any]: + user = self._validate_user(user) payload = self._build_transfer_to_subaccount( - source_index, target_index, transfer_asset, transfer_amount, target_account_type + source_index, + target_index, + transfer_asset, + transfer_amount, + target_account_type, + user=user, ) response_data = await self._send_transfer_request(payload) @@ -807,10 +827,17 @@ async def _send_transfer_request(self, payload): return response_data def _build_transfer_to_subaccount( - self, source_index, target_index, transfer_asset, transfer_amount, target_account_type=None + self, + source_index, + target_index, + transfer_asset, + transfer_amount, + target_account_type=None, + user: "User | None" = None, ): + user = self._validate_user(user) inner = { - "user": self.user.public_key, + "user": user.public_key, "source_index": source_index, "target_index": target_index, "transfer_asset": str(transfer_asset), @@ -819,7 +846,7 @@ def _build_transfer_to_subaccount( } self.logger.debug(f"Transfer request: {inner}") signing_payload = json.dumps(inner, separators=(",", ":")) - signature = list(self.user.sign(signing_payload)) + signature = list(user.sign(signing_payload)) payload = { "inner": inner, "signature": signature, @@ -827,23 +854,29 @@ def _build_transfer_to_subaccount( } return payload - async def request_close_position(self, account: int, transfer_asset: str) -> dict[str, Any]: - payload = self._build_close_position_request(account, transfer_asset) + async def request_close_position( + self, account: int, transfer_asset: str, user: "User | None" = None + ) -> dict[str, Any]: + user = self._validate_user(user) + payload = self._build_close_position_request(account, transfer_asset, user=user) response_data = await self._send_close_position_request(payload) return response_data - def _build_close_position_request(self, account: int, transfer_asset: str) -> dict: + def _build_close_position_request( + self, account: int, transfer_asset: str, user: "User | None" = None + ) -> dict: """Same signing rules as CE: compact JSON of inner, ed25519 over UTF-8 bytes.""" + user = self._validate_user(user) inner = { - "user": self.user.public_key, + "user": user.public_key, "account": account, "asset_identifier": transfer_asset, } self.logger.debug(f"Preparing close position request: {inner}") signing_payload = json.dumps(inner, separators=(",", ":")) - signature = list(self.user.sign(signing_payload)) + signature = list(user.sign(signing_payload)) payload = { "inner": inner, "signature": signature, diff --git a/tplus/evm/managers/chaindata.py b/tplus/evm/managers/chaindata.py index c8a8e75..dacda22 100644 --- a/tplus/evm/managers/chaindata.py +++ b/tplus/evm/managers/chaindata.py @@ -24,8 +24,8 @@ def __init__( chain_id: ChainID | None = None, ): self.default_user = default_user - self.ce: ClearingEngineClient = clearing_engine or ClearingEngineClient( - self.default_user, "http://127.0.0.1:3032" + self.ce: ClearingEngineClient = clearing_engine or ClearingEngineClient.from_local( + self.default_user ) self.chain_id = chain_id or ChainID.evm(self.chain_manager.chain_id) diff --git a/tplus/evm/managers/settle.py b/tplus/evm/managers/settle.py index 33201c8..a8c5ecc 100644 --- a/tplus/evm/managers/settle.py +++ b/tplus/evm/managers/settle.py @@ -61,8 +61,8 @@ def __init__( ): self.default_user = default_user self.ape_account = ape_account - self.ce: ClearingEngineClient = clearing_engine or ClearingEngineClient( - self.default_user, "http://127.0.0.1:3032" + self.ce: ClearingEngineClient = clearing_engine or ClearingEngineClient.from_local( + self.default_user ) self.chain_id = chain_id or ChainID.evm(self.chain_manager.chain_id) self.vault = vault or DepositVault(chain_id=self.chain_id) diff --git a/tplus/exceptions.py b/tplus/exceptions.py new file mode 100644 index 0000000..a7b8623 --- /dev/null +++ b/tplus/exceptions.py @@ -0,0 +1,17 @@ +class BaseTplusException(Exception): + """ + Base exception class. + """ + + +class MissingClientUserError(BaseTplusException): + """ + Raised when a user is not specified in a client when required. + """ + + def __init__(self, context: str | None = None): + message = "User required. Create client with a default user or specify user in request." + if context is not None: + message = f"{message}. Context: {context}" + + super().__init__(message) diff --git a/tplus/types.py b/tplus/types.py new file mode 100644 index 0000000..3b6aa26 --- /dev/null +++ b/tplus/types.py @@ -0,0 +1,4 @@ +from tplus.model.types import UserPublicKey +from tplus.utils.user import User + +UserType = UserPublicKey | User From 1a8f5268a4909b6d37170cf4110e271f90ad45ef Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 7 Apr 2026 20:21:52 -0500 Subject: [PATCH 2/2] fmt py310 --- tplus/client/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tplus/client/base.py b/tplus/client/base.py index 5fe118d..ef79441 100644 --- a/tplus/client/base.py +++ b/tplus/client/base.py @@ -3,12 +3,13 @@ import ssl from collections.abc import AsyncIterator, Callable from functools import cached_property -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse import httpx import websockets from pydantic import BaseModel +from typing_extensions import Self from tplus.exceptions import MissingClientUserError from tplus.logger import get_logger