from __future__ import annotations

import base64
import binascii
import dataclasses
import datetime as dt
import hashlib
import json
import typing
import urllib.parse
from collections.abc import Iterable

from yt_dlp.extractor.youtube.pot._provider import (
    BuiltinIEContentProvider,
    IEContentProvider,
    IEContentProviderLogger,
)
from yt_dlp.extractor.youtube.pot._registry import (
    _pot_cache_provider_preferences,
    _pot_cache_providers,
    _pot_pcs_providers,
    _pot_providers,
    _ptp_preferences,
)
from yt_dlp.extractor.youtube.pot.cache import (
    CacheProviderWritePolicy,
    PoTokenCacheProvider,
    PoTokenCacheProviderError,
    PoTokenCacheSpec,
    PoTokenCacheSpecProvider,
)
from yt_dlp.extractor.youtube.pot.provider import (
    PoTokenProvider,
    PoTokenProviderError,
    PoTokenProviderRejectedRequest,
    PoTokenRequest,
    PoTokenResponse,
    provider_bug_report_message,
)
from yt_dlp.utils import bug_reports_message, format_field, join_nonempty

if typing.TYPE_CHECKING:
    from yt_dlp.extractor.youtube.pot.cache import CacheProviderPreference
    from yt_dlp.extractor.youtube.pot.provider import Preference


class YoutubeIEContentProviderLogger(IEContentProviderLogger):
    """Logger that forwards PO Token component messages to an extractor.

    A message is emitted only when ``log_level`` is at or below the level of
    the method called. Every message is passed through ``_format_msg``, which
    prepends ``format_field(self.prefix, None, '[%s] ')``.
    """

    def __init__(self, ie, prefix, log_level: IEContentProviderLogger.LogLevel | None = None):
        self.__ie = ie
        self.prefix = prefix
        # Fall back to INFO when the caller does not pick a level
        if log_level is None:
            log_level = self.LogLevel.INFO
        self.log_level = log_level

    def __enabled_for(self, level) -> bool:
        """Whether messages of the given level pass the configured threshold"""
        return self.log_level <= level

    def _format_msg(self, message: str):
        """Return the message with the formatted prefix prepended"""
        return f"{format_field(self.prefix, None, '[%s] ')}{message}"

    def trace(self, message: str):
        """Write a 'TRACE: '-tagged message through the extractor's write_debug()"""
        if self.__enabled_for(self.LogLevel.TRACE):
            self.__ie.write_debug(self._format_msg('TRACE: ' + message))

    def debug(self, message: str):
        """Write the message through the extractor's write_debug()"""
        if self.__enabled_for(self.LogLevel.DEBUG):
            self.__ie.write_debug(self._format_msg(message))

    def info(self, message: str):
        """Write the message through the extractor's to_screen()"""
        if self.__enabled_for(self.LogLevel.INFO):
            self.__ie.to_screen(self._format_msg(message))

    def warning(self, message: str, *, once=False):
        """Report a warning through the extractor; `once` is forwarded as `only_once`"""
        if self.__enabled_for(self.LogLevel.WARNING):
            self.__ie.report_warning(self._format_msg(message), only_once=once)

    def error(self, message: str):
        """Report through the downloader's report_error() with is_error=False"""
        if self.__enabled_for(self.LogLevel.ERROR):
            self.__ie._downloader.report_error(self._format_msg(message), is_error=False)


class PoTokenCache:
    """Reads and writes PO Token responses through the registered cache providers.

    The cache key is a hash of the key bindings of a `PoTokenCacheSpec`, which is
    taken from the first available cache spec provider that returns a valid spec
    for the request. When no spec can be obtained, nothing is fetched or stored.
    """

    def __init__(
        self,
        logger: IEContentProviderLogger,
        cache_providers: list[PoTokenCacheProvider],
        cache_spec_providers: list[PoTokenCacheSpecProvider],
        cache_provider_preferences: list[CacheProviderPreference] | None = None,
    ):
        """Index the given providers by their PROVIDER_KEY.

        `cache_providers`, `cache_spec_providers` and `cache_provider_preferences`
        may be falsy, in which case they are treated as empty.
        """
        self.cache_providers: dict[str, PoTokenCacheProvider] = {
            provider.PROVIDER_KEY: provider for provider in (cache_providers or [])}
        self.cache_provider_preferences: list[CacheProviderPreference] = cache_provider_preferences or []
        self.cache_spec_providers: dict[str, PoTokenCacheSpecProvider] = {
            provider.PROVIDER_KEY: provider for provider in (cache_spec_providers or [])}
        self.logger = logger

    def _get_cache_providers(self, request: PoTokenRequest) -> Iterable[PoTokenCacheProvider]:
        """Sorts available cache providers by preference, given a request

        Each provider is scored with the sum of all cache provider preferences.
        The returned generator yields providers from highest to lowest score and
        calls is_available() on each one only when it is reached during iteration.
        """
        preferences = {
            provider: sum(pref(provider, request) for pref in self.cache_provider_preferences)
            for provider in self.cache_providers.values()
        }
        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # calling is_available() for every PO Token provider upfront may have some overhead
            self.logger.trace(f'PO Token Cache Providers: {provider_display_list(self.cache_providers.values())}')
            self.logger.trace('Cache Provider preferences for this request: {}'.format(', '.join(
                f'{provider.PROVIDER_KEY}={pref}' for provider, pref in preferences.items())))

        return (
            provider for provider in sorted(
                self.cache_providers.values(), key=preferences.get, reverse=True) if provider.is_available())

    def _get_cache_spec(self, request: PoTokenRequest) -> PoTokenCacheSpec | None:
        """Return the first valid cache spec generated for the request, or None.

        Unavailable spec providers and providers that return a falsy spec are skipped.
        Invalid specs and any exception raised by a provider are logged as errors
        and the next provider is tried.
        """
        for provider in self.cache_spec_providers.values():
            if not provider.is_available():
                continue
            try:
                spec = provider.generate_cache_spec(request)
                if not spec:
                    continue
                if not validate_cache_spec(spec):
                    self.logger.error(
                        f'PoTokenCacheSpecProvider "{provider.PROVIDER_KEY}" generate_cache_spec() '
                        f'returned invalid spec {spec}{provider_bug_report_message(provider)}')
                    continue
                # Record which provider produced the spec; it becomes part of the key bindings
                spec = dataclasses.replace(spec, _provider=provider)
                self.logger.trace(
                    f'Retrieved cache spec {spec} from cache spec provider "{provider.PROVIDER_NAME}"')
                return spec
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache spec provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')
                continue
        return None

    def _generate_key_bindings(self, spec: PoTokenCacheSpec) -> dict[str, str]:
        """Build the bindings the cache key is derived from.

        Bindings whose value is None are dropped, a '_dlp_cache' version marker is
        added, and '_p' is set to the spec provider's PROVIDER_KEY when the spec has one.
        """
        bindings_cleaned = {
            **{k: v for k, v in spec.key_bindings.items() if v is not None},
            # Allow us to invalidate caches if such need arises
            '_dlp_cache': 'v1',
        }
        if spec._provider:
            bindings_cleaned['_p'] = spec._provider.PROVIDER_KEY
        self.logger.trace(f'Generated cache key bindings: {bindings_cleaned}')
        return bindings_cleaned

    def _generate_key(self, bindings: dict) -> str:
        """Return the SHA-256 hex digest of the repr of the bindings sorted by key"""
        # ''.join() over a str returns an equal str, so this is just the repr of the sorted dict
        binding_string = ''.join(repr(dict(sorted(bindings.items()))))
        return hashlib.sha256(binding_string.encode()).hexdigest()

    def get(self, request: PoTokenRequest) -> PoTokenResponse | None:
        """Fetch a cached PO Token response for the request.

        Cache providers are tried in preference order and the first valid response
        is returned. A cached value that cannot be parsed or fails validation is
        logged and deleted from that provider. Provider errors are logged and the
        next provider is tried. Returns None when no spec or no valid entry exists.
        """
        spec = self._get_cache_spec(request)
        if not spec:
            self.logger.trace('No cache spec available for this request, unable to fetch from cache')
            return None

        cache_key = self._generate_key(self._generate_key_bindings(spec))
        self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}')

        for idx, provider in enumerate(self._get_cache_providers(request)):
            try:
                self.logger.trace(
                    f'Attempting to fetch PO Token response from "{provider.PROVIDER_NAME}" cache provider')
                cache_response = provider.get(cache_key)
                if not cache_response:
                    continue
                try:
                    po_token_response = PoTokenResponse(**json.loads(cache_response))
                except (TypeError, ValueError, json.JSONDecodeError):
                    # Undecodable payloads are handled by the validation failure path below
                    po_token_response = None
                if not validate_response(po_token_response):
                    self.logger.error(
                        f'Invalid PO Token response retrieved from cache provider "{provider.PROVIDER_NAME}": '
                        f'{cache_response}{provider_bug_report_message(provider)}')
                    provider.delete(cache_key)
                    continue
                self.logger.trace(
                    f'PO Token response retrieved from cache using "{provider.PROVIDER_NAME}" provider: '
                    f'{po_token_response}')
                if idx > 0:
                    # Write back to the highest priority cache provider,
                    # so we stop trying to fetch from lower priority providers
                    self.logger.trace('Writing PO Token response to highest priority cache provider')
                    self.store(request, po_token_response, write_policy=CacheProviderWritePolicy.WRITE_FIRST)

                return po_token_response
            except PoTokenCacheProviderError as e:
                # The bug report hint is only appended for errors not marked as expected
                self.logger.warning(
                    f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
                continue
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider)}',
                )
                continue
        return None

    def store(
        self,
        request: PoTokenRequest,
        response: PoTokenResponse,
        write_policy: CacheProviderWritePolicy | None = None,
    ):
        """Store a PO Token response in the cache providers.

        Nothing is stored when no cache spec is available or when the response fails
        validation. A falsy `response.expires_at` is replaced by the current UTC time
        plus the spec's `default_ttl`. A falsy `write_policy` falls back to the
        spec's write policy. Provider errors are logged and do not propagate.
        """
        spec = self._get_cache_spec(request)
        if not spec:
            self.logger.trace('No cache spec available for this request. Not caching.')
            return

        if not validate_response(response):
            self.logger.error(
                f'Invalid PO Token response provided to PoTokenCache.store(): '
                f'{response}{bug_reports_message()}')
            return

        cache_key = self._generate_key(self._generate_key_bindings(spec))
        self.logger.trace(f'Attempting to access PO Token cache using key: {cache_key}')

        default_expires_at = int(dt.datetime.now(dt.timezone.utc).timestamp()) + spec.default_ttl
        # Work on a copy so the caller's response object is left untouched
        cache_response = dataclasses.replace(response, expires_at=response.expires_at or default_expires_at)

        write_policy = write_policy or spec.write_policy
        self.logger.trace(f'Using write policy: {write_policy}')

        for idx, provider in enumerate(self._get_cache_providers(request)):
            try:
                self.logger.trace(
                    f'Caching PO Token response in "{provider.PROVIDER_NAME}" cache provider '
                    f'(key={cache_key}, expires_at={cache_response.expires_at})')
                provider.store(
                    key=cache_key,
                    value=json.dumps(dataclasses.asdict(cache_response)),
                    expires_at=cache_response.expires_at)
            except PoTokenCacheProviderError as e:
                self.logger.warning(
                    f'Error from "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
            except Exception as e:
                self.logger.error(
                    f'Error occurred with "{provider.PROVIDER_NAME}" PO Token cache provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')

            # WRITE_FIRST should not write to lower priority providers in the case the highest priority provider fails
            if idx == 0 and write_policy == CacheProviderWritePolicy.WRITE_FIRST:
                return

    def close(self):
        """Close every cache provider, then every cache spec provider"""
        for provider in self.cache_providers.values():
            provider.close()
        for spec_provider in self.cache_spec_providers.values():
            spec_provider.close()


class PoTokenRequestDirector:
    """Resolves PO Token requests using the cache and the registered providers.

    Providers are tried in descending order of their summed preference scores.
    The first valid response is cleaned, cached when its expiry allows it, and
    its token returned.
    """

    def __init__(self, logger: IEContentProviderLogger, cache: PoTokenCache):
        self.logger = logger
        self.cache = cache
        self.providers: dict[str, PoTokenProvider] = {}
        self.preferences: list[Preference] = []

    def register_provider(self, provider: PoTokenProvider):
        """Register a provider under its PROVIDER_KEY, replacing any earlier entry"""
        self.providers[provider.PROVIDER_KEY] = provider

    def register_preference(self, preference: Preference):
        """Add a preference callable used to score providers for a request"""
        self.preferences.append(preference)

    def _get_providers(self, request: PoTokenRequest) -> Iterable[PoTokenProvider]:
        """Return the available providers for a request, highest summed preference first"""
        preferences = {}
        for provider in self.providers.values():
            preferences[provider] = sum(pref(provider, request) for pref in self.preferences)

        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # Listing the providers calls is_available() on each of them, so only do it when tracing
            self.logger.trace(f'PO Token Providers: {provider_display_list(self.providers.values())}')
            summary = ', '.join(
                f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items())
            self.logger.trace('Provider preferences for this request: {}'.format(summary))

        ranked = sorted(self.providers.values(), key=preferences.get, reverse=True)
        # Availability is checked lazily, as each provider is reached by the consumer
        return (provider for provider in ranked if provider.is_available())

    def _get_po_token(self, request) -> PoTokenResponse | None:
        """Return the first valid response from the ranked providers, or None"""
        for provider in self._get_providers(request):
            try:
                self.logger.trace(
                    f'Attempting to fetch a PO Token from "{provider.PROVIDER_NAME}" provider')
                # Each provider is handed its own copy of the request
                response = provider.request_pot(request.copy())
            except PoTokenProviderRejectedRequest as e:
                self.logger.trace(
                    f'PO Token Provider "{provider.PROVIDER_NAME}" rejected this request, '
                    f'trying next available provider. Reason: {e}')
            except PoTokenProviderError as e:
                self.logger.warning(
                    f'Error fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
                    f'{e!r}{provider_bug_report_message(provider) if not e.expected else ""}')
            except Exception as e:
                self.logger.error(
                    f'Unexpected error when fetching PO Token from "{provider.PROVIDER_NAME}" provider: '
                    f'{e!r}{provider_bug_report_message(provider)}')
            else:
                self.logger.trace(f'PO Token response from "{provider.PROVIDER_NAME}" provider: {response}')
                if validate_response(response):
                    return response
                self.logger.error(
                    f'Invalid PO Token response received from "{provider.PROVIDER_NAME}" provider: '
                    f'{response}{provider_bug_report_message(provider)}')

        self.logger.trace('No PO Token providers were able to provide a valid PO Token')
        return None

    def get_po_token(self, request: PoTokenRequest) -> str | None:
        """Return a cleaned PO Token for the request, or None if none could be obtained.

        The cache is consulted first unless `request.bypass_cache` is set. A freshly
        fetched response is stored in the cache when its `expires_at` is None or positive.
        """
        cached_response = None if request.bypass_cache else self.cache.get(request)
        if cached_response:
            return clean_pot(cached_response.po_token)

        if not self.providers:
            self.logger.trace('No PO Token providers registered')
            return None

        pot_response = self._get_po_token(request)
        if not pot_response:
            return None

        pot_response.po_token = clean_pot(pot_response.po_token)

        cacheable = pot_response.expires_at is None or pot_response.expires_at > 0
        if not cacheable:
            self.logger.trace(
                f'PO Token response will not be cached (expires_at={pot_response.expires_at})')
        else:
            self.cache.store(request, pot_response)

        return pot_response.po_token

    def close(self):
        """Close every registered provider, then the cache"""
        for provider in self.providers.values():
            provider.close()
        self.cache.close()


# Prefix of the extractor-args key that holds a provider's settings: initialize_pot_director()
# looks them up under f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'
EXTRACTOR_ARG_PREFIX = 'youtubepot'


def initialize_pot_director(ie):
    """Create a PoTokenRequestDirector for the given extractor.

    Instantiates every registered cache provider, cache spec provider and PO Token
    provider with its own prefixed logger and its extractor-args settings, wires
    them into a PoTokenCache and a director, and registers the director's close()
    as a close hook on the downloader. The extractor must have a downloader set.
    """
    assert ie._downloader is not None, 'Downloader not set'

    # The log level comes from the 'pot_trace' extractor arg, then from the 'verbose' param
    levels = IEContentProviderLogger.LogLevel
    trace_arg = ie._configuration_arg(
        'pot_trace', ['false'], ie_key='youtube', casesense=False)[0]
    if trace_arg == 'true':
        log_level = levels.TRACE
    else:
        log_level = levels.DEBUG if ie.get_param('verbose', False) else levels.INFO

    def make_logger(prefix):
        return YoutubeIEContentProviderLogger(ie, prefix, log_level=log_level)

    def instantiate(provider, logger_key):
        # Build the provider with its own logger and its settings from extractor_args
        logger_prefix = f'{logger_key}:{provider.PROVIDER_NAME}'
        extractor_key = f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'
        provider_logger = make_logger(logger_prefix)
        settings = ie.get_param('extractor_args', {}).get(extractor_key, {})
        return provider(ie, provider_logger, settings)

    cache_providers = [
        instantiate(cache_provider, 'pot:cache')
        for cache_provider in _pot_cache_providers.value.values()]
    cache_spec_providers = [
        instantiate(cache_spec_provider, 'pot:cache:spec')
        for cache_spec_provider in _pot_pcs_providers.value.values()]

    cache = PoTokenCache(
        logger=make_logger('pot:cache'),
        cache_providers=cache_providers,
        cache_spec_providers=cache_spec_providers,
        cache_provider_preferences=list(_pot_cache_provider_preferences.value),
    )
    director = PoTokenRequestDirector(logger=make_logger('pot'), cache=cache)

    # Register the close hook before any PO Token provider is constructed
    ie._downloader.add_close_hook(director.close)

    for pot_provider in _pot_providers.value.values():
        director.register_provider(instantiate(pot_provider, 'pot'))

    for preference in _ptp_preferences.value:
        director.register_preference(preference)

    log = director.logger
    if log.log_level <= log.LogLevel.DEBUG:
        # calling is_available() for every PO Token provider upfront may have some overhead
        log.debug(f'PO Token Providers: {provider_display_list(director.providers.values())}')
        log.debug(f'PO Token Cache Providers: {provider_display_list(cache.cache_providers.values())}')
        log.debug(f'PO Token Cache Spec Providers: {provider_display_list(cache.cache_spec_providers.values())}')
        log.trace(f'Registered {len(director.preferences)} provider preferences')
        log.trace(f'Registered {len(cache.cache_provider_preferences)} cache provider preferences')

    return director


def provider_display_list(providers: Iterable[IEContentProvider]):
    """Return a comma-separated, human-readable listing of the providers.

    Providers that are not BuiltinIEContentProvider instances are shown with
    their version and tagged 'external'. Providers whose is_available() is falsy
    are tagged 'unavailable'. An empty iterable yields 'none'.
    """
    def describe(provider):
        is_external = not isinstance(provider, BuiltinIEContentProvider)
        # The version is only shown for external providers
        version = provider.PROVIDER_VERSION if is_external else None
        label = join_nonempty(provider.PROVIDER_NAME, version)

        statuses = ['external'] if is_external else []
        if not provider.is_available():
            statuses.append('unavailable')
        if not statuses:
            return label
        return label + f' ({", ".join(statuses)})'

    return ', '.join(map(describe, providers)) or 'none'


def clean_pot(po_token: str) -> str:
    """Clean and validate a PO Token.

    The token is URL-unquoted, decoded as URL-safe base64 and re-encoded, which
    strips invalid characters off (e.g. additional url params the user may
    accidentally include).

    Returns the re-encoded token.
    Raises ValueError('Invalid PO Token') if the token cannot be decoded; the
    underlying decode error is attached as the exception's cause.
    """
    try:
        decoded = base64.urlsafe_b64decode(urllib.parse.unquote(po_token))
    except (binascii.Error, ValueError) as e:
        # Chain explicitly so the traceback shows the decode failure as the direct cause
        # rather than as an unrelated error raised during handling
        raise ValueError('Invalid PO Token') from e
    return base64.urlsafe_b64encode(decoded).decode()


def validate_response(response: PoTokenResponse | None):
    """Check whether a PO Token response is usable.

    A response is valid when it is a PoTokenResponse whose `po_token` is a
    non-empty str accepted by clean_pot(), and whose `expires_at` is either None,
    an int <= 0, or an int later than the current UTC timestamp.
    """
    if not isinstance(response, PoTokenResponse):
        return False

    token = response.po_token
    if not (isinstance(token, str) and token):
        return False

    try:
        clean_pot(token)
    except ValueError:
        return False

    expires_at = response.expires_at
    if expires_at is None:
        return True
    if not isinstance(expires_at, int):
        return False
    if expires_at <= 0:
        return True

    # A positive expiry must still lie in the future
    now = int(dt.datetime.now(dt.timezone.utc).timestamp())
    return expires_at > now


def validate_cache_spec(spec: PoTokenCacheSpec):
    """Check whether a cache spec is well-formed.

    The spec must be a PoTokenCacheSpec with a CacheProviderWritePolicy
    `write_policy`, an int `default_ttl`, and a `key_bindings` dict that maps str
    keys to str or None values, with at least one value that is not None.
    """
    if not isinstance(spec, PoTokenCacheSpec):
        return False
    if not isinstance(spec.write_policy, CacheProviderWritePolicy):
        return False
    if not isinstance(spec.default_ttl, int):
        return False

    bindings = spec.key_bindings
    if not isinstance(bindings, dict):
        return False
    if not all(isinstance(key, str) for key in bindings):
        return False
    if not all(value is None or isinstance(value, str) for value in bindings.values()):
        return False

    # At least one binding must carry an actual value
    return any(value is not None for value in bindings.values())
