Source code for molten.testing.client

# This file is a part of molten.
#
# Copyright (C) 2018 CLEARTYPE SRL <[email protected]>
#
# molten is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# molten is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import json
import os
import random
import string
from functools import partial
from io import BytesIO
from json import dumps as to_json
from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, Union, cast
from urllib.parse import urlencode

from ..app import BaseApp
from ..http import HTTP_200
from ..http.headers import Headers, HeadersDict
from ..http.query_params import ParamsDict, QueryParams
from ..http.request import Request
from ..http.response import Response
from ..typing import Environ
from .common import to_environ

# gunicorn is treated as optional here: use its NoMoreData exception
# when it can be imported, otherwise define a local stand-in with the
# same name so that the ``except NoMoreData`` clause in
# ``TestClient.request`` still has something to refer to.
try:
    from gunicorn.http.errors import NoMoreData  # type: ignore
except ImportError:  # pragma: no cover
    class NoMoreData(Exception):  # type: ignore
        pass


# Lowercase HTTP method names.  ``TestClient.__getattr__`` turns each of
# these into a shortcut for ``TestClient.request``.
HTTP_METHODS = {
    "delete",
    "get",
    "head",
    "options",
    "patch",
    "post",
    "put",
}


class TestResponse:
    """A wrapper around Response objects that adds a few additional
    helper methods for testing.

    Any attribute that is not defined on the wrapper itself is looked
    up on the wrapped response object.

    Attributes:
      status: The response status line.
      status_code: The response status code as an integer.
      headers: The response headers.
      stream: The response data as a binary file.
      data: The response data as a string.
    """

    __slots__ = ["_response"]

    def __init__(self, response: Response) -> None:
        self._response = response

    @property
    def data(self) -> str:
        """Rewinds the output stream and returns all its data,
        decoded as UTF-8.
        """
        self._response.stream.seek(0)
        return self._response.stream.read().decode("utf-8")

    @property
    def status_code(self) -> int:
        """Returns the HTTP status code as an integer.

        The code is taken from the part of the status line that
        precedes the first space.
        """
        code, _, _ = self._response.status.partition(" ")
        return int(code)

    def json(self) -> Any:
        """Convert the response data to JSON.
        """
        return json.loads(self.data)

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped response.

        Raises:
          AttributeError: If the wrapped response has not been set
            yet, or if it does not have the requested attribute.
        """
        # __getattr__ only runs after normal lookup has failed, so
        # being asked for "_response" means the slot is still empty
        # (e.g. the instance was created via __new__, as copy.copy
        # does).  Reading self._response below would then re-enter
        # this method forever and end in a RecursionError.
        if name == "_response":
            raise AttributeError(name)

        return getattr(self._response, name)
# Value type of the ``files`` mapping accepted by ``TestClient.request``.
# A 2-tuple is unpacked there as (file name, file stream); a 3-tuple as
# (file name, content type, file stream).
FileSpec = Union[
    Tuple[str, BinaryIO],
    Tuple[str, str, BinaryIO],
]
class TestClient:
    """Test clients are used to simulate requests against an
    application instance.

    Every name in ``HTTP_METHODS`` is also available as a shortcut
    attribute, so ``client.get("/")`` is the same as
    ``client.request("get", "/")``.
    """

    __slots__ = ["app"]

    def __init__(self, app: BaseApp) -> None:
        self.app = app

    def request(
            self,
            method: str,
            path: str,
            headers: Optional[Union[HeadersDict, Headers]] = None,
            params: Optional[Union[ParamsDict, QueryParams]] = None,
            body: Optional[bytes] = None,
            data: Optional[Dict[str, str]] = None,
            files: Optional[Dict[str, FileSpec]] = None,
            json: Optional[Any] = None,
            auth: Optional[Callable[[Request], Request]] = None,
            prepare_environ: Optional[Callable[[Environ], Environ]] = None,
    ) -> TestResponse:
        """Simulate a request against the application.

        The request body is built from the first of 'body', 'files'
        (together with 'data' as plain form fields), 'data' and 'json'
        that is not None.

        Raises:
          RuntimeError: If 'json' is provided together with either
            'data' or 'files'.

        Parameters:
          method: The request method.
          path: The request path.
          headers: Optional request headers.
          params: Optional query params.
          body: An optional bytestring for the request body.
          data: An optional dictionary for the request body that gets
            url-encoded.
          files: An optional dictionary of files to upload as part of
            a multipart request.
          json: An optional value for the request body that gets
            json-encoded.
          auth: An optional function that can be used to add auth
            headers to the request.
          prepare_environ: An optional function that receives the
            environ built from the request and returns the environ
            that is passed to the application.

        Returns:
          The application's response, wrapped in a TestResponse.
        """
        if json is not None and (data is not None or files is not None):
            raise RuntimeError("either 'data'/'files' or 'json' should be provided, not both")

        request = Request(
            method=method.upper(),
            path=path,
            headers=headers,
            params=params,
        )

        if body is not None:
            request.headers["content-length"] = f"{len(body)}"
            request.body_file = BytesIO(body)

        elif files is not None:
            boundary = "--" + "".join(random.choice(string.ascii_letters) for _ in range(30))
            request.body_file = content = BytesIO()

            # The boundary value itself begins with two dashes, so this
            # first line is not a part delimiter (those are written as
            # "--" + boundary below); it precedes the first part.
            content.write(f"{boundary}\r\n".encode())

            for field_name, file_spec in files.items():
                file_type = "application/octet-stream"
                if len(file_spec) == 3:
                    file_name, file_type, file_stream = file_spec  # type: ignore
                else:
                    file_name, file_stream = file_spec  # type: ignore

                content.write(f"--{boundary}\r\n".encode())
                content.write(f'content-disposition: multipart/form-data; name="{field_name}"; filename="{file_name}"\r\n'.encode())
                content.write(f'content-type: {file_type}\r\n'.encode())
                content.write(b"\r\n")
                content.write(file_stream.read())
                content.write(b"\r\n")

            # When files are uploaded, 'data' is sent as additional
            # multipart fields rather than url-encoded.
            for field_name, field_value in (data or {}).items():
                content.write(f"--{boundary}\r\n".encode())
                content.write(f'content-disposition: multipart/form-data; name="{field_name}"\r\n'.encode())
                content.write(b"\r\n")
                content.write(field_value.encode())
                content.write(b"\r\n")

            content.write(f"--{boundary}--\r\n".encode())
            request.headers["content-type"] = f"multipart/form-data; boundary={boundary}"
            # The write position is at the end of the buffer here, so
            # tell() is the total body size.
            request.headers["content-length"] = str(content.tell())
            content.seek(0, os.SEEK_SET)

        elif data is not None:
            request_content = urlencode(data).encode("utf-8")
            request.headers["content-type"] = "application/x-www-form-urlencoded"
            request.headers["content-length"] = f"{len(request_content)}"
            request.body_file = BytesIO(request_content)

        elif json is not None:
            request_content = to_json(json).encode("utf-8")
            request.headers["content-type"] = "application/json; charset=utf-8"
            request.headers["content-length"] = f"{len(request_content)}"
            request.body_file = BytesIO(request_content)

        if auth is not None:
            request = auth(request)

        response = Response(HTTP_200)

        def start_response(status, response_headers, exc_info=None):  # type: ignore
            response.status = status
            # NOTE(review): dict() keeps only the last value when the
            # same header name occurs more than once in
            # response_headers -- confirm that is acceptable.
            response.headers = Headers(dict(response_headers))

        try:
            environ = to_environ(request)
            if prepare_environ:  # pragma: no cover
                environ = prepare_environ(environ)

            chunks = self.app(environ, start_response)
        except NoMoreData:
            chunks = []

        # Chunked responses are handed over as-is so the caller can
        # consume them incrementally; closing the iterable is then up
        # to the caller.
        if response.headers.get("transfer-encoding") == "chunked":
            response.stream = cast(BinaryIO, chunks)
            return TestResponse(response)

        try:
            for chunk in chunks:
                response.stream.write(chunk)
        finally:
            # WSGI callers must call close() on the application's
            # iterable, if it has one, once they are done with it, so
            # that the application can release its resources.
            close = getattr(chunks, "close", None)
            if close is not None:
                close()

        response.stream.seek(0)
        return TestResponse(response)

    def __getattr__(self, name: str) -> Any:
        """Expose each name in HTTP_METHODS as a shortcut for
        :meth:`request` with the method already filled in.

        Raises:
          AttributeError: If the name is not one of HTTP_METHODS.
        """
        if name in HTTP_METHODS:
            return partial(self.request, name)

        raise AttributeError(f"unknown attribute {name}")