# Source code for molten.http.headers
# This file is a part of molten.
#
# Copyright (C) 2018 CLEARTYPE SRL <[email protected]>
#
# molten is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# molten is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
from ..errors import HeaderMissing
from ..typing import Environ
#: Type alias for a plain dictionary of headers: each header name maps
#: either to a single value or to a list of values.
HeadersDict = Dict[str, Union[str, List[str]]]

#: WSGI environ keys that carry request headers without the ``HTTP_``
#: prefix used for all other headers, so they must be looked up by name.
CONTENT_VARS = {"CONTENT_TYPE", "CONTENT_LENGTH"}
class Headers(Iterable[Tuple[str, str]]):
    """A mapping from case-insensitive header names to lists of values.

    Header names are lowercased on every read and write, so
    ``headers["Content-Type"]`` and ``headers["content-type"]`` refer to
    the same entry.  A header may carry several values, kept in the order
    they were added.  Looking up a missing header never creates an entry.
    """

    __slots__ = ["_headers"]

    def __init__(self, mapping: Optional[HeadersDict] = None) -> None:
        """Create a headers collection, optionally seeded from ``mapping``.

        Parameters:
          mapping: Header names mapped to a value or a list of values.
        """
        # Lowercased header name -> list of values.  The defaultdict is only
        # relied on by ``add``; read paths use ``.get`` so they never insert.
        self._headers: Dict[str, List[str]] = defaultdict(list)
        self.add_all(mapping or {})

    @classmethod
    def from_environ(cls, environ: Environ) -> "Headers":
        """Construct a Headers instance from a WSGI environ.

        Keys listed in ``CONTENT_VARS`` are taken as-is (with underscores
        turned into dashes); other headers are read from ``HTTP_``-prefixed
        keys.  All remaining environ keys are ignored.
        """
        headers: HeadersDict = {}
        for name, value in environ.items():
            if name in CONTENT_VARS:
                headers[name.replace("_", "-")] = value
            elif name.startswith("HTTP_"):
                headers[_parse_environ_header(name)] = [value]
        return cls(headers)

    def add(self, header: str, value: Union[str, List[str]]) -> None:
        """Add values for a particular header.

        Parameters:
          header: The header name (case-insensitive).
          value: A single value or a list of values, appended after any
            values already stored for ``header``.
        """
        values = self._headers[header.lower()]
        if isinstance(value, list):
            values.extend(value)
        else:
            values.append(value)

    def add_all(self, mapping: HeadersDict) -> None:
        """Add a group of headers.

        Each entry goes through :meth:`add`, so values are appended to any
        existing values rather than replacing them.
        """
        for header, value_or_values in mapping.items():
            self.add(header, value_or_values)

    def get(self, header: str, default: Optional[str] = None) -> Optional[str]:
        """Get the last value for a given header, or ``default`` if missing."""
        try:
            return self[header]
        except HeaderMissing:
            return default

    def get_all(self, header: str) -> List[str]:
        """Get all the values for a given header.

        Returns:
          A new list of the header's values in insertion order, or an empty
          list when the header is missing.  Mutating the returned list does
          not affect this object.
        """
        return list(self._headers.get(header.lower(), ()))

    def get_int(self, header: str, default: Optional[int] = None) -> Optional[int]:
        """Get the last value for a given header as an integer.

        Returns ``default`` when the header is missing.

        Raises:
          ValueError: When the header's last value is not a valid integer.
        """
        try:
            value = self[header]
        except HeaderMissing:
            return default
        return int(value)

    def __delitem__(self, header: str) -> None:
        """Delete all the values for a given header.

        Raises:
          KeyError: When the header is missing.
        """
        del self._headers[header.lower()]

    def __getitem__(self, header: str) -> str:
        """Get the last value for a given header.

        Raises:
          HeaderMissing: When the header is missing or has no values.
        """
        # ``.get`` avoids defaultdict's insert-on-miss behaviour.
        values = self._headers.get(header.lower())
        if not values:
            raise HeaderMissing(header)
        return values[-1]

    def __setitem__(self, header: str, value: str) -> None:
        """Replace all of a header's values with the single ``value``."""
        self._headers[header.lower()] = [value]

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        """Iterate over all the headers as ``(name, value)`` pairs.

        Names are yielded lowercased; a header with several values is
        yielded once per value.
        """
        for header, values in self._headers.items():
            for value in values:
                yield header, value

    def __repr__(self) -> str:
        mapping = ", ".join(f"{name!r}: {values!r}" for name, values in self._headers.items())
        return f"Headers({{{mapping}}})"
#: Length of the ``HTTP_`` prefix that WSGI puts in front of request header
#: names in the environ; that many leading characters are dropped when
#: turning an environ key back into a header name.
HEADER_PREFIX_LEN = len("HTTP_")

#: Memoizes WSGI environ keys (e.g. ``HTTP_USER_AGENT``) to the header
#: names derived from them (e.g. ``USER-AGENT``).
HEADER_PARSER_CACHE: Dict[str, str] = dict()
#: Upper bound on the number of entries memoized in ``HEADER_PARSER_CACHE``.
#: Environ keys are derived from client-sent header names, so an unbounded
#: cache could be grown without limit by sending many distinct names.
_HEADER_PARSER_CACHE_MAX_SIZE = 1024


def _parse_environ_header(header: str) -> str:
    """Convert a WSGI environ key into a header name.

    Drops the ``HTTP_`` prefix and turns underscores into dashes, e.g.
    ``HTTP_USER_AGENT`` becomes ``USER-AGENT``.  Results are memoized in
    ``HEADER_PARSER_CACHE`` until it holds ``_HEADER_PARSER_CACHE_MAX_SIZE``
    entries.  After that, new keys are still parsed but are not cached.

    Parameters:
      header: An environ key expected to start with ``HTTP_``
        (``Headers.from_environ`` only passes such keys).
    """
    parsed_header = HEADER_PARSER_CACHE.get(header)
    if parsed_header is None:
        parsed_header = header[HEADER_PREFIX_LEN:].replace("_", "-")
        # Cap the cache so attacker-chosen header names can't exhaust memory.
        if len(HEADER_PARSER_CACHE) < _HEADER_PARSER_CACHE_MAX_SIZE:
            HEADER_PARSER_CACHE[header] = parsed_header
    return parsed_header