335 lines
10 KiB
Python
335 lines
10 KiB
Python
"""
|
|
The MIT License (MIT)
|
|
|
|
Copyright (c) 2015-present Rapptz
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a
|
|
copy of this software and associated documentation files (the "Software"),
|
|
to deal in the Software without restriction, including without limitation
|
|
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
and/or sell copies of the Software, and to permit persons to whom the
|
|
Software is furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in
|
|
all copies or substantial portions of the Software.
|
|
|
|
The Software is modified as follows:
|
|
- Delete unused functions and method.
|
|
- Removing functions beyond what is necessary to make it work.
|
|
- Simplification of some functions.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
DEALINGS IN THE SOFTWARE.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
import re
|
|
import sys
|
|
import traceback
|
|
from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union
|
|
|
|
from aiohttp import ClientWebSocketResponse
|
|
from mipac.client import Client as API
|
|
from mipac.manager.client import ClientManager
|
|
from mipac.models.user import MeDetailed
|
|
|
|
from mipa.exception import WebSocketNotConnected, WebSocketReconnect
|
|
from mipa.gateway import MisskeyWebSocket
|
|
from mipa.router import Router
|
|
from mipa.state import ConnectionState
|
|
from mipa.utils import LOGING_LEVEL_TYPE, setup_logging
|
|
|
|
# Module-level logger. Library code logs through its own named logger rather
# than the root logger, so applications can filter/configure it by name.
_log = logging.getLogger(__name__)
|
|
|
|
|
|
class Client:
    """Misskey bot client.

    Holds the mipac API session (``core``), the streaming websocket (``ws``)
    and two handler registries, both mapping an ``on_<event>`` name to a list
    of coroutine functions:

    * ``extra_events`` -- filled by :meth:`listen` / :meth:`add_listener`,
      run by :meth:`dispatch`;
    * ``special_events`` -- filled by :meth:`event` / :meth:`add_event`,
      run by :meth:`event_dispatch`.
    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        max_capture: int = 100,
        **options: Any,
    ):
        """
        Parameters
        ----------
        loop : asyncio.AbstractEventLoop, optional
            Loop on which event handlers are scheduled. Defaults to
            ``asyncio.get_event_loop()``.
        max_capture : int, default 100
            Stored as ``self.max_capture``; not read in this class
            (presumably used by the gateway/state -- TODO confirm).
        **options
            Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(**options)
        self.max_capture = max_capture
        # Not assigned in this class; see the ``router`` property.
        self._router: Router
        self.url: Optional[str] = None
        self.extra_events: Dict[str, Any] = {}
        self.special_events: Dict[str, Any] = {}
        self.token: Optional[str] = None
        # ``start`` stores the HTTP(S) origin in ``origin_url``; initialise it
        # here so the attribute exists before ``start`` has run.
        # ``origin_uri`` is kept for backward compatibility.
        self.origin_uri: Optional[str] = None
        self.origin_url: Optional[str] = None
        self.loop = asyncio.get_event_loop() if loop is None else loop
        # Declared here, assigned later: ``core`` in create_api_session,
        # ``_connection`` in _connect, ``user`` in login.
        self.core: API
        self._connection: ConnectionState
        self.user: MeDetailed
        self.ws: Optional[MisskeyWebSocket] = None
        self.should_reconnect = True

    def _get_state(self, **options: Any) -> ConnectionState:
        """Create the :class:`ConnectionState` bound to this client.

        ``options`` is accepted but currently unused.
        """
        return ConnectionState(
            dispatch=self.dispatch, loop=self.loop, client=self
        )

    async def on_ready(self, ws: ClientWebSocketResponse):
        """Default ``on_ready`` handler; does nothing.

        Parameters
        ----------
        ws : ClientWebSocketResponse
            The connected websocket.
        """

    def event(self, name: Optional[str] = None):
        """Decorator form of :meth:`add_event`; returns the function unchanged."""

        def decorator(func: Callable[..., Coroutine[Any, Any, Any]]):
            self.add_event(func, name)
            return func

        return decorator

    def add_event(
        self,
        func: Callable[..., Coroutine[Any, Any, Any]],
        name: Optional[str] = None,
    ):
        """Register ``func`` as a handler run by :meth:`event_dispatch`.

        Parameters
        ----------
        func : coroutine function
            The handler.
        name : str, optional
            Registry key, e.g. ``"on_ready"``. Defaults to ``func.__name__``.

        Raises
        ------
        TypeError
            If ``func`` is not a coroutine function.
        """
        name = func.__name__ if name is None else name
        if not asyncio.iscoroutinefunction(func):
            raise TypeError("Listeners must be coroutines")
        # Look up and append in the same registry (the membership test used
        # to check ``extra_events`` by mistake).
        self.special_events.setdefault(name, []).append(func)

    def listen(self, name: Optional[str] = None):
        """Decorator form of :meth:`add_listener`; returns the function unchanged."""

        def decorator(func: Callable[..., Coroutine[Any, Any, Any]]):
            self.add_listener(func, name)
            return func

        return decorator

    def add_listener(
        self,
        func: Callable[..., Coroutine[Any, Any, Any]],
        name: Optional[str] = None,
    ):
        """Register ``func`` as a listener run by :meth:`dispatch`.

        Parameters
        ----------
        func : coroutine function or bound coroutine method
            The listener.
        name : str, optional
            Registry key, e.g. ``"on_note"``. Defaults to ``func.__name__``.

        Raises
        ------
        TypeError
            If ``func`` is not a coroutine function.
        """
        name = func.__name__ if name is None else name
        if not asyncio.iscoroutinefunction(func):
            raise TypeError("Listeners must be coroutines")
        _log.debug("add_listener: %s %s", name, func.__name__)
        self.extra_events.setdefault(name, []).append(func)

    @staticmethod
    def _resolve_handler(func: Callable[..., Any]) -> Callable[..., Any]:
        """Return the module-level binding of a top-level ``func``, else ``func``.

        Re-reading the function from its module presumably lets a reloaded
        module's new definition be used (the original code did a module
        lookup too). Nested functions, class members, or handlers whose
        module is not loaded fall back to ``func`` itself: ``getattr`` with a
        dotted ``__qualname__`` finds nothing and returns the default.
        """
        module = sys.modules.get(getattr(func, "__module__", None))
        qualname = getattr(func, "__qualname__", None)
        if module is None or qualname is None:
            return func
        return getattr(module, qualname, func)

    def _schedule_special_events(
        self, ev: str, *args: Any, **kwargs: Any
    ) -> None:
        """Schedule every handler registered via :meth:`add_event` for ``ev``."""
        for handler in self.special_events.get(ev, []):
            self.schedule_event(
                self._resolve_handler(handler), ev, *args, **kwargs
            )

    def event_dispatch(
        self, event_name: str, *args: Any, **kwargs: Any
    ) -> bool:
        """Schedule the ``on_<event_name>`` handlers (``on_ready`` and the like).

        Schedules every handler registered via :meth:`add_event`, then the
        client's own ``on_<event_name>`` attribute if it has one.

        Parameters
        ----------
        event_name : str
            Event name without the ``on_`` prefix, e.g. ``"ready"``.
        *args, **kwargs
            Passed to each handler.

        Returns
        -------
        bool
            Whether the client has an ``on_<event_name>`` attribute.
        """
        ev = f"on_{event_name}"
        self._schedule_special_events(ev, *args, **kwargs)
        has_own_handler = hasattr(self, ev)
        if has_own_handler:
            self.schedule_event(getattr(self, ev), ev, *args, **kwargs)
        return has_own_handler

    def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
        """Schedule the listeners for ``on_<event_name>``.

        Schedules every listener registered via :meth:`add_listener`, then
        the client's own ``on_<event_name>`` attribute if it exists. Handlers
        are only scheduled as tasks on ``self.loop``; this does not wait.
        """
        ev = f"on_{event_name}"
        for listener in self.extra_events.get(ev, []):
            if inspect.ismethod(listener):
                # Bound methods are called as-is, labelled by method name.
                self.schedule_event(
                    listener, listener.__name__, *args, **kwargs
                )
            else:
                self.schedule_event(
                    self._resolve_handler(listener), ev, *args, **kwargs
                )
        if hasattr(self, ev):
            self.schedule_event(getattr(self, ev), ev, *args, **kwargs)

    def schedule_event(
        self,
        coro: Callable[..., Coroutine[Any, Any, Any]],
        event_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> asyncio.Task[Any]:
        """Run ``coro(*args, **kwargs)`` as a task named ``"MiPA: <event_name>"``."""
        return self.loop.create_task(
            self._run_event(coro, event_name, *args, **kwargs),
            name=f"MiPA: {event_name}",
        )

    async def _run_event(
        self,
        coro: Callable[..., Coroutine[Any, Any, Any]],
        event_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Await a handler, reporting (not propagating) its exceptions.

        Cancellation is swallowed both here and in the error reporter.
        """
        try:
            await coro(*args, **kwargs)
        except asyncio.CancelledError:
            pass
        except Exception:
            try:
                await self.__on_error(event_name)
            except asyncio.CancelledError:
                pass

    @staticmethod
    async def __on_error(event_method: str) -> None:
        """Print the exception currently being handled to stderr."""
        print(f"Ignoring exception in {event_method}", file=sys.stderr)
        traceback.print_exc()

    async def on_error(self, err):
        """Default ``on_error`` handler.

        Forwards ``err`` to the ``on_error`` handlers registered via
        :meth:`add_event`. It deliberately does not go through
        :meth:`event_dispatch`, which would schedule this very method again
        and spawn an endless chain of tasks.
        """
        self._schedule_special_events("on_error", err)

    async def create_api_session(
        self,
        token: str,
        url: str,
        log_level: LOGING_LEVEL_TYPE | None,
    ) -> API:
        """Create the mipac API client, store it as ``self.core`` and return it."""
        self.core = API(url, token, log_level=log_level)
        return self.core

    async def setup_hook(self) -> None:
        """Hook awaited at the end of :meth:`login`; no-op by default."""

    async def login(
        self, token: str, url: str, log_level: LOGING_LEVEL_TYPE | None
    ):
        """Log in as the bot user and fetch its profile into ``self.user``.

        Parameters
        ----------
        token : str
            Token of the user the bot runs as.
        url : str
            HTTP(S) URL of the instance that user belongs to.
        log_level : LOGING_LEVEL_TYPE or None
            Log level passed to the mipac API client.
        """
        core = await self.create_api_session(token, url, log_level)
        await core.http.login()
        self.user = await core.api.get_me()
        await self.setup_hook()

    async def _connect(
        self,
        *,
        timeout: int = 60,
        event_name: str = "ready",
    ) -> None:
        """Open the websocket and poll events until an exception escapes.

        ``timeout`` bounds the websocket handshake and is also forwarded to
        ``MisskeyWebSocket.from_client``.
        """
        self._connection = self._get_state()
        coro = MisskeyWebSocket.from_client(
            self, timeout=timeout, event_name=event_name
        )
        # Honour the caller's timeout (it used to be hard-coded to 60).
        self.ws = await asyncio.wait_for(coro, timeout=timeout)
        while True:
            await self.ws.poll_event()

    async def connect(
        self,
        *,
        reconnect: bool = True,
        timeout: int = 60,
    ) -> None:
        """Connect and keep the connection alive.

        On ``WebSocketReconnect`` or a timeout, waits 3 seconds and
        reconnects (dispatching ``"reconnect"`` instead of ``"ready"``),
        unless ``should_reconnect`` is false, in which case it returns.
        """
        self.should_reconnect = reconnect
        event_name = "ready"
        while True:
            try:
                await self._connect(timeout=timeout, event_name=event_name)
            except (WebSocketReconnect, asyncio.exceptions.TimeoutError):
                if not self.should_reconnect:
                    break
                event_name = "reconnect"
                await asyncio.sleep(3)

    async def disconnect(self):
        """Disable reconnection and close the websocket.

        Raises
        ------
        WebSocketNotConnected
            If no websocket has been opened.
        """
        if not self.ws:
            raise WebSocketNotConnected()
        self.should_reconnect = False
        await self.ws.socket.close()

    @property
    def client(self) -> ClientManager:
        """The mipac client manager (``self.core.api``); requires login first."""
        return self.core.api

    @property
    def router(self) -> Router:
        """The :class:`Router`; ``_router`` is not assigned in this class."""
        return self._router

    @staticmethod
    def _build_urls(url: str) -> Tuple[str, str]:
        """Derive the streaming websocket URL and the HTTP(S) origin.

        Accepts ``wss://host``, ``wss://host/streaming``, ``https://host``
        (or their ``ws``/``http`` variants), with an optional trailing ``/``.

        Returns
        -------
        tuple[str, str]
            ``(websocket_url, origin_url)``, e.g.
            ``("wss://host/streaming", "https://host")``.
        """
        # Drop a single trailing slash; ``endswith`` also copes with "".
        url = url[:-1] if url.endswith("/") else url

        if match := re.search(r"wss?://(.*)", url):
            # Rewrite only the scheme (wss -> https, ws -> http); a blanket
            # str.replace("ws", ...) would also corrupt hosts like "*.news".
            origin_url = re.sub(r"^ws", "http", match.group(0), count=1)
            origin_url = origin_url.replace("/streaming", "")
        else:
            origin_url = url

        parts = url.split("/")
        if "streaming" not in parts:
            parts.append("streaming")
        # Same scheme-only rewrite: https -> wss, http -> ws.
        ws_url = re.sub(r"^http", "ws", "/".join(parts), count=1)
        return ws_url, origin_url

    async def start(
        self,
        url: str,
        token: str,
        *,
        debug: bool = False,
        reconnect: bool = True,
        timeout: int = 60,
        is_ayuskey: bool = False,
        log_level: LOGING_LEVEL_TYPE | None = "INFO",
    ):
        """Start the bot: log in, then connect to the streaming API.

        Parameters
        ----------
        url: str
            Misskey instance URL, e.g. ``wss://example.com``,
            ``wss://example.com/streaming`` or ``https://example.com``.
        token: str
            User token.
        debug: bool, default False
            Accepted but not used by this method.
        reconnect: bool, default True
            Reconnect after a websocket reconnect request or timeout.
        timeout: int, default 60
            Time until the websocket connection times out.
        is_ayuskey: bool, default False
            Accepted but not used by this method.
        log_level: LOGING_LEVEL_TYPE or None, default "INFO"
            When not None, passed to ``setup_logging``; also forwarded to
            the API client.
        """
        if log_level is not None:
            setup_logging(level=log_level)
        self.token = token
        self.url, self.origin_url = self._build_urls(url)
        await self.login(token, self.origin_url, log_level)
        await self.connect(reconnect=reconnect, timeout=timeout)
|