diff --git a/command.py b/command.py
index 6131081..afbb242 100644
--- a/command.py
+++ b/command.py
@@ -1,245 +1,245 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
import datetime
from typing import List, Optional, Tuple, Union

from action import Action
from page import Page
from siteinfo import CategoryInfo


-@dataclass
+@dataclass(frozen=True)
class Command:
    """A list of actions to perform on a page."""

    page: Page
    actions: List['Action']

    def apply(self, wikitext: str, category_info: CategoryInfo) -> Tuple[str, List[Tuple[Action, bool]]]:
        """Apply the actions of this command to the given wikitext and return
        the result as well as the actions together with the information
        whether they were a no-op or not.
        """
        actions = []
        for action in self.actions:
            new_wikitext = action.apply(wikitext, category_info)
            actions.append((action, wikitext == new_wikitext))
            wikitext = new_wikitext
        return wikitext, actions

    def cleanup(self) -> None:
        """Partially normalize the command, as a convenience for users.

        This should not be used as a replacement for full normalization
        via the MediaWiki API.
        """
        self.page.cleanup()
        for action in self.actions:
            action.cleanup()

    def actions_tpsv(self) -> str:
        return '|'.join([str(action) for action in self.actions])

    def __str__(self) -> str:
        return str(self.page) + '|' + self.actions_tpsv()


-@dataclass
+@dataclass(frozen=True)
class CommandRecord(ABC):
    """A command that was recorded in some store."""

    id: int
    command: Command


class CommandPlan(CommandRecord):
    """A command that should be run in the future."""

    def __str__(self) -> str:
        return str(self.command)


class CommandPending(CommandRecord):
    """A command that is about to be run or currently running."""

    def __str__(self) -> str:
        return str(self.command)


class CommandFinish(CommandRecord):
    """A command that was intended to be run at some point
    and should now no longer be run."""

    def __str__(self) -> str:
        return '# ' + str(self.command)


class CommandSuccess(CommandFinish):
    """A command that was successfully run."""


-@dataclass
+@dataclass(frozen=True)
class CommandEdit(CommandSuccess):
    """A command that resulted in an edit on a page."""

    base_revision: int
    revision: int

    def __post_init__(self) -> None:
        assert self.base_revision < self.revision


-@dataclass
+@dataclass(frozen=True)
class CommandNoop(CommandSuccess):
    """A command that resulted in no change to a page."""

    revision: int


class CommandFailure(CommandFinish):
    """A command that was not successfully run."""

    @abstractmethod
    def can_retry_immediately(self) -> bool:
        """Whether it is okay to retry running this command immediately.

        In case of an immediate retry, no permanent record of the failure is kept,
        so this should not be used if the failure resulted in any actions on the wiki."""

    @abstractmethod
    def can_retry_later(self) -> bool:
        """Whether it is okay to retry running this command at a later time.

        If True, a new command plan for the same command with a fresh ID
        will be appended to the end of the batch."""

    @abstractmethod
    def can_continue_batch(self) -> Union[bool, datetime.datetime]:
        """Whether it is okay to continue with other commands in this batch.

        If the failure only affects this command, we can proceed with the batch as usual;
        if other commands are likely to fail for the same reason,
        or we should back off for some other reason,
        the batch should be suspended for a time.

        True means that we can continue immediately;
        False means that we should suspend the batch for an unspecified time
        (i. e., stop background runs and only proceed on manual input by the user);
        a datetime means that we should suspend the batch until that time
        (i. e., suspend background runs but resume them automatically)."""


-@dataclass
+@dataclass(frozen=True)
class CommandPageMissing(CommandFailure):
    """A command that failed because the specified page was found to be missing at the time."""

    curtimestamp: str

    def can_retry_immediately(self) -> bool:
        return False

    def can_retry_later(self) -> bool:
        return False

    def can_continue_batch(self) -> bool:
        return True


-@dataclass
+@dataclass(frozen=True)
class CommandTitleInvalid(CommandFailure):
    """A command that failed because the specified title was invalid."""

    curtimestamp: str

    def can_retry_immediately(self) -> bool:
        return False

    def can_retry_later(self) -> bool:
        return False

    def can_continue_batch(self) -> bool:
        return True


-@dataclass
+@dataclass(frozen=True)
class CommandPageProtected(CommandFailure):
    """A command that failed because the specified page was protected at the time."""

    curtimestamp: str

    def can_retry_immediately(self) -> bool:
        return False

    def can_retry_later(self) -> bool:
        return False

    def can_continue_batch(self) -> bool:
        return True


class CommandEditConflict(CommandFailure):
    """A command that failed due to an edit conflict."""

    def can_retry_immediately(self) -> bool:
        return True

    def can_retry_later(self) -> bool:
        return True

    def can_continue_batch(self) -> bool:
        return True


-@dataclass
+@dataclass(frozen=True)
class CommandMaxlagExceeded(CommandFailure):
    """A command that failed because replication lag in the database cluster was too high."""

    retry_after: datetime.datetime

    def can_retry_immediately(self) -> bool:
        return False

    def can_retry_later(self) -> bool:
        return True

    def can_continue_batch(self) -> datetime.datetime:
        return self.retry_after


-@dataclass
+@dataclass(frozen=True)
class CommandBlocked(CommandFailure):
    """A command that failed because the user or IP address was blocked."""

    auto: bool
    blockinfo: Optional[dict]

    def can_retry_immediately(self) -> bool:
        return False

    def can_retry_later(self) -> bool:
        return True

    def can_continue_batch(self) -> bool:
        # we could perhaps continue the batch if the block is partial
        # (that is, if self.blockinfo['blockpartial'] is True),
        # but for now I’d rather not
        return False


-@dataclass
+@dataclass(frozen=True)
class CommandWikiReadOnly(CommandFailure):
    """A command that failed because the wiki was in read-only mode."""

    reason: Optional[str]
    retry_after: Optional[datetime.datetime]

    def can_retry_immediately(self) -> bool:
        return False

    def can_retry_later(self) -> bool:
        return True

    def can_continue_batch(self) -> Union[bool, datetime.datetime]:
        return self.retry_after or False
diff --git a/database.py b/database.py
index abec8ce..9f6e475 100644
--- a/database.py
+++ b/database.py
@@ -1,697 +1,697 @@
import contextlib
from dataclasses import dataclass
import datetime
import itertools
import json
import mwapi # type: ignore
import mwoauth # type: ignore
import pymysql
import requests_oauthlib # type: ignore
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Type, cast

from batch import NewBatch, StoredBatch, OpenBatch, ClosedBatch, BatchCommandRecords, BatchBackgroundRuns
from command import Command, CommandPlan, CommandPending, CommandRecord, CommandFinish, CommandEdit, CommandNoop, CommandFailure, CommandPageMissing, CommandTitleInvalid, CommandPageProtected, CommandEditConflict, CommandMaxlagExceeded, CommandBlocked, CommandWikiReadOnly
from localuser import LocalUser
from page import Page
import parse_tpsv
from querytime import QueryTimingCursor, QueryTimingSSCursor
from store import BatchStore, _local_user_from_session
from stringstore import StringTableStore
from timestamp import now, datetime_to_utc_timestamp, utc_timestamp_to_datetime


class DatabaseStore(BatchStore):

    _BATCH_STATUS_OPEN = 0
    _BATCH_STATUS_CLOSED = 128

    _COMMAND_STATUS_PLAN = 0
    _COMMAND_STATUS_EDIT = 1
    _COMMAND_STATUS_NOOP = 2
    _COMMAND_STATUS_PENDING = 16
    _COMMAND_STATUS_PAGE_MISSING = 129
    _COMMAND_STATUS_EDIT_CONFLICT = 130
    _COMMAND_STATUS_MAXLAG_EXCEEDED = 131
    _COMMAND_STATUS_BLOCKED = 132
    _COMMAND_STATUS_WIKI_READ_ONLY = 133
    _COMMAND_STATUS_PAGE_PROTECTED = 134
    _COMMAND_STATUS_TITLE_INVALID = 135

    def __init__(self, connection_params: dict) -> None:
        connection_params.setdefault('charset', 'utf8mb4')
        self.connection_params = connection_params
        self.domain_store = StringTableStore('domain', 'domain_id', 'domain_hash', 'domain_name')
        self.title_store = StringTableStore('title', 'title_id', 'title_hash', 'title_text')
        self.actions_store = StringTableStore('actions', 'actions_id', 'actions_hash', 'actions_tpsv')
        self.local_user_store = _LocalUserStore(self.domain_store)

    @contextlib.contextmanager
    def connect(self) -> Generator[pymysql.connections.Connection, None, None]:
        connection = pymysql.connect(cursorclass=QueryTimingCursor, **self.connection_params)
        try:
            yield connection
        finally:
            connection.close()

    @contextlib.contextmanager
    def connect_streaming(self) -> Generator[pymysql.connections.Connection, None, None]:
        connection = pymysql.connect(cursorclass=QueryTimingSSCursor, **self.connection_params)
        try:
            yield connection
        finally:
            connection.close()

    def store_batch(self, new_batch: NewBatch, session: mwapi.Session) -> OpenBatch:
        created = now()
        created_utc_timestamp = datetime_to_utc_timestamp(created)
        local_user = _local_user_from_session(session)

        with self.connect() as connection:
            domain_id = self.domain_store.acquire_id(connection, local_user.domain)
            if new_batch.title:
                title_id = self.title_store.acquire_id(connection, new_batch.title)
            else:
                title_id = None
            localuser_id = self.local_user_store.acquire_localuser_id(connection, local_user)
            with connection.cursor() as cursor:
                cursor.execute('INSERT INTO `batch` (`batch_localuser`, `batch_domain`, `batch_title`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`) VALUES (%s, %s, %s, %s, %s, %s)',
                               (localuser_id, domain_id, title_id, created_utc_timestamp, created_utc_timestamp, DatabaseStore._BATCH_STATUS_OPEN))
                batch_id = cursor.lastrowid
            with connection.cursor() as cursor:
                cursor.executemany('INSERT INTO `command` (`command_batch`, `command_page_title`, `command_page_resolve_redirects`, `command_actions`, `command_status`, `command_outcome`) VALUES (%s, %s, %s, %s, %s, NULL)',
                                   [(batch_id, command.page.title, command.page.resolve_redirects, self.actions_store.acquire_id(connection, command.actions_tpsv()), DatabaseStore._COMMAND_STATUS_PLAN) for command in new_batch.commands])
            connection.commit()

        return OpenBatch(batch_id, local_user, local_user.domain, new_batch.title, created, created, _BatchCommandRecordsDatabase(batch_id, self), _BatchBackgroundRunsDatabase(batch_id, local_user.domain, self))

    def get_batch(self, id: int) -> Optional[StoredBatch]:
        with self.connect() as connection:
            with connection.cursor() as cursor:
                cursor.execute('''SELECT `batch_id`, `localuser_user_name`, `localuser_local_user_id`, `localuser_global_user_id`, `domain_name`, `title_text`, `batch_created_utc_timestamp`,
                                         `batch_last_updated_utc_timestamp`, `batch_status`
                                  FROM `batch`
                                  JOIN `domain` ON `batch_domain` = `domain_id`
                                  JOIN `localuser` ON `batch_localuser` = `localuser_id`
                                  LEFT JOIN `title` ON `batch_title` = `title_id`
                                  WHERE `batch_id` = %s''', (id,))
                result = cursor.fetchone()
                if not result:
                    return None
                return self._result_to_batch(result)

    def _result_to_batch(self, result: tuple) -> StoredBatch:
        id, user_name, local_user_id, global_user_id, domain, title, created_utc_timestamp, last_updated_utc_timestamp, status = result
        created = utc_timestamp_to_datetime(created_utc_timestamp)
        last_updated = utc_timestamp_to_datetime(last_updated_utc_timestamp)
        local_user = LocalUser(user_name, domain, local_user_id, global_user_id)
        if status == DatabaseStore._BATCH_STATUS_OPEN:
            return OpenBatch(id, local_user, local_user.domain, title, created, last_updated, _BatchCommandRecordsDatabase(id, self), _BatchBackgroundRunsDatabase(id, local_user.domain, self))
        elif status == DatabaseStore._BATCH_STATUS_CLOSED:
            return ClosedBatch(id, local_user, local_user.domain, title, created, last_updated, _BatchCommandRecordsDatabase(id, self), _BatchBackgroundRunsDatabase(id, local_user.domain, self))
        else:
            raise ValueError('Unknown batch type')

    def get_batches_slice(self, offset: int, limit: int) -> Sequence[StoredBatch]:
        with self.connect() as connection:
            with connection.cursor() as cursor:
                cursor.execute('''SELECT `batch_id`, `localuser_user_name`, `localuser_local_user_id`, `localuser_global_user_id`, `domain_name`, `title_text`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`
                                  FROM `batch`
                                  JOIN `domain` ON `batch_domain` = `domain_id`
                                  JOIN `localuser` ON `batch_localuser` = `localuser_id`
                                  LEFT JOIN `title` ON `batch_title` = `title_id`
                                  ORDER BY `batch_id` DESC
                                  LIMIT %s OFFSET %s''', (limit, offset))
                return [self._result_to_batch(result) for result in cursor.fetchall()]

    def get_batches_count(self) -> int:
        with self.connect() as connection:
            with connection.cursor() as cursor:
                cursor.execute('''SELECT COUNT(*) AS `count` FROM `batch`''')
                result = cursor.fetchone()
                assert result, "COUNT(*) must return a result"
                (count,) = result
                return count

    def start_background(self, batch: OpenBatch, session: mwapi.Session) -> None:
        started = now()
        started_utc_timestamp = datetime_to_utc_timestamp(started)
        local_user = _local_user_from_session(session)
        with self.connect() as connection:
            with connection.cursor() as cursor:
                cursor.execute('''SELECT 1
                                  FROM `background`
                                  WHERE `background_batch` = %s
                                  AND `background_stopped_utc_timestamp` IS NULL
                                  FOR UPDATE''', (batch.id,))
                if cursor.fetchone():
                    connection.commit() # finish the FOR UPDATE
                    return
            assert isinstance(session.session.auth, requests_oauthlib.OAuth1)
            auth = {'resource_owner_key': session.session.auth.client.resource_owner_key,
                    'resource_owner_secret': session.session.auth.client.resource_owner_secret}
            localuser_id = self.local_user_store.acquire_localuser_id(connection, local_user)
            with connection.cursor() as cursor:
                cursor.execute('''INSERT INTO `background` (`background_batch`, `background_auth`, `background_started_utc_timestamp`, `background_started_localuser`)
                                  VALUES (%s, %s, %s, %s)''', (batch.id, json.dumps(auth), started_utc_timestamp, localuser_id))
            connection.commit()

    def stop_background(self, batch: StoredBatch, session: Optional[mwapi.Session] = None) -> None:
        self._stop_background_by_id(batch.id, session)

    def _stop_background_by_id(self, batch_id: int, session: Optional[mwapi.Session] = None) -> None:
        stopped = now()
        stopped_utc_timestamp = datetime_to_utc_timestamp(stopped)
        with self.connect() as connection:
            localuser_id: Optional[int]
            if session:
                local_user = _local_user_from_session(session)
                localuser_id = self.local_user_store.acquire_localuser_id(connection, local_user)
            else:
                localuser_id = None
            with connection.cursor() as cursor:
                cursor.execute('''UPDATE `background`
                                  SET `background_auth` = NULL,
                                      `background_stopped_utc_timestamp` = %s,
                                      `background_stopped_localuser` = %s,
                                      `background_suspended_until_utc_timestamp` = NULL
                                  WHERE `background_batch` = %s
                                  AND `background_stopped_utc_timestamp` IS NULL''', (stopped_utc_timestamp, localuser_id, batch_id))
            connection.commit()
            if cursor.rowcount > 1:
                raise RuntimeError('Should have stopped at most 1 background operation, actually affected %d!' % cursor.rowcount)

    def suspend_background(self, batch: StoredBatch, until: datetime.datetime) -> None:
        until_utc_timestamp = datetime_to_utc_timestamp(until)
        with self.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''UPDATE `background`
                              SET `background_suspended_until_utc_timestamp` = %s
                              WHERE `background_batch` = %s
                              AND `background_stopped_utc_timestamp` IS NULL''', (until_utc_timestamp, batch.id))
            connection.commit()
            if cursor.rowcount > 1:
                raise RuntimeError('Should have suspended at most 1 background run, actually affected %d!' % cursor.rowcount)

    def make_plan_pending_background(self, consumer_token: mwoauth.ConsumerToken, user_agent: str) -> Optional[Tuple[OpenBatch, CommandPending, mwapi.Session]]:
        with self.connect() as connection:
            # find a planned command and lock it
            with connection.cursor() as cursor:
                now_utc_timestamp = datetime_to_utc_timestamp(now())
                cursor.execute('''SELECT `command_id`, `batch_id`
                                  FROM `background`
                                  JOIN `batch` ON `background_batch` = `batch_id`
                                  JOIN `command` ON `command_batch` = `batch_id`
                                  WHERE `background_stopped_utc_timestamp` IS NULL
                                  AND COALESCE(`background_suspended_until_utc_timestamp`, 0) < %s
                                  AND `command_status` = %s
                                  ORDER BY `batch_last_updated_utc_timestamp` ASC, `command_id` ASC
                                  LIMIT 1
                                  FOR UPDATE''', (now_utc_timestamp, DatabaseStore._COMMAND_STATUS_PLAN))
                result = cursor.fetchone()
                if not result:
                    connection.commit() # finish the FOR UPDATE
                    return None
                command_id = result[0]
                batch_id = result[1]

            # make it pending
            with connection.cursor() as cursor:
                cursor.execute('''UPDATE `command`
                                  SET `command_status` = %s
                                  WHERE `command_id` = %s AND `command_batch` = %s''', (DatabaseStore._COMMAND_STATUS_PENDING, command_id, batch_id))
            connection.commit()

            # get the rest of the data now that we know we need it (without locking it)
            with connection.cursor() as cursor:
                cursor.execute('''SELECT `batch_id`, `localuser_user_name`, `localuser_local_user_id`, `localuser_global_user_id`, `domain_name`, `title_text`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`, `background_auth`, `command_id`, `command_page_title`, `command_page_resolve_redirects`, `actions_tpsv`
                                  FROM `background`
                                  JOIN `batch` ON `background_batch` = `batch_id`
                                  JOIN `command` ON `command_batch` = `batch_id`
                                  JOIN `domain` ON `batch_domain` = `domain_id`
                                  JOIN `actions` ON `command_actions` = `actions_id`
                                  JOIN `localuser` ON `batch_localuser` = `localuser_id`
                                  LEFT JOIN `title` ON `batch_title` = `title_id`
                                  WHERE `command_id` = %s''', (command_id))
                assert cursor.rowcount == 1
                result = cursor.fetchone()
                assert result is not None

            auth_data = json.loads(result[9])
            auth = requests_oauthlib.OAuth1(client_key=consumer_token.key, client_secret=consumer_token.secret, resource_owner_key=auth_data['resource_owner_key'], resource_owner_secret=auth_data['resource_owner_secret'])
            session = mwapi.Session(host='https://'+result[4], auth=auth, user_agent=user_agent)
            command_pending = self._row_to_command_record(result[10], result[11], result[12], result[13], DatabaseStore._COMMAND_STATUS_PENDING, outcome=None)
            batch = self._result_to_batch(result[0:9])
            assert isinstance(batch, OpenBatch), "must be open since at least one command is still pending"
            assert isinstance(command_pending, CommandPending), "must be pending since we just set that status"
            return batch, command_pending, session

    def _command_finish_to_row(self, command_finish: CommandFinish) -> Tuple[int, dict]:
        status: int
        outcome: dict
        if isinstance(command_finish, CommandEdit):
            status = DatabaseStore._COMMAND_STATUS_EDIT
            outcome = {'base_revision': command_finish.base_revision, 'revision': command_finish.revision}
        elif isinstance(command_finish, CommandNoop):
            status = DatabaseStore._COMMAND_STATUS_NOOP
            outcome = {'revision': command_finish.revision}
        elif isinstance(command_finish, CommandPageMissing):
            status = DatabaseStore._COMMAND_STATUS_PAGE_MISSING
            outcome = {'curtimestamp': command_finish.curtimestamp}
        elif isinstance(command_finish, CommandTitleInvalid):
            status = DatabaseStore._COMMAND_STATUS_TITLE_INVALID
            outcome = {'curtimestamp': command_finish.curtimestamp}
        elif isinstance(command_finish, CommandPageProtected):
            status = DatabaseStore._COMMAND_STATUS_PAGE_PROTECTED
            outcome = {'curtimestamp': command_finish.curtimestamp}
        elif isinstance(command_finish, CommandEditConflict):
            status = DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT
            outcome = {}
        elif isinstance(command_finish, CommandMaxlagExceeded):
            status = DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED
            outcome = {'retry_after_utc_timestamp': datetime_to_utc_timestamp(command_finish.retry_after)}
        elif isinstance(command_finish, CommandBlocked):
            status = DatabaseStore._COMMAND_STATUS_BLOCKED
            outcome = {'auto': command_finish.auto, 'blockinfo': command_finish.blockinfo}
        elif isinstance(command_finish, CommandWikiReadOnly):
            status = DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY
            outcome = {'reason': command_finish.reason}
            if command_finish.retry_after:
                outcome['retry_after_utc_timestamp'] = datetime_to_utc_timestamp(command_finish.retry_after)
        else:
            raise ValueError('Unknown command type')
        return status, outcome

    def _row_to_page(self, title: str, resolve_redirects: Optional[int]) -> Page:
        return Page(title, self._tinyint_to_bool(resolve_redirects))

    def _row_to_command(self, title: str, resolve_redirects: Optional[int], actions_tpsv: str) -> Command:
        return Command(self._row_to_page(title, self._tinyint_to_bool(resolve_redirects)),
                       [parse_tpsv.parse_action(field) for field in actions_tpsv.split('|')])

    def _row_to_command_record(self, id: int, title: str, resolve_redirects: Optional[int], actions_tpsv: str, status: int, outcome: Optional[str]) -> CommandRecord:
        if outcome:
            outcome_dict = json.loads(outcome)
        command = self._row_to_command(title, resolve_redirects, actions_tpsv)
        if status == DatabaseStore._COMMAND_STATUS_PLAN:
            assert outcome is None
            return CommandPlan(id, command)
        elif status == DatabaseStore._COMMAND_STATUS_EDIT:
            return CommandEdit(id, command, base_revision=outcome_dict['base_revision'], revision=outcome_dict['revision'])
        elif status == DatabaseStore._COMMAND_STATUS_NOOP:
            return CommandNoop(id, command, revision=outcome_dict['revision'])
        elif status == DatabaseStore._COMMAND_STATUS_PENDING:
            assert outcome is None
            return CommandPending(id, command)
        elif status == DatabaseStore._COMMAND_STATUS_PAGE_MISSING:
            return CommandPageMissing(id, command, curtimestamp=outcome_dict['curtimestamp'])
        elif status == DatabaseStore._COMMAND_STATUS_TITLE_INVALID:
            return CommandTitleInvalid(id, command, curtimestamp=outcome_dict['curtimestamp'])
        elif status == DatabaseStore._COMMAND_STATUS_PAGE_PROTECTED:
            return CommandPageProtected(id, command, curtimestamp=outcome_dict['curtimestamp'])
        elif status == DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT:
            return CommandEditConflict(id, command)
        elif status == DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED:
            return CommandMaxlagExceeded(id, command, utc_timestamp_to_datetime(outcome_dict['retry_after_utc_timestamp']))
        elif status == DatabaseStore._COMMAND_STATUS_BLOCKED:
            return CommandBlocked(id, command, auto=outcome_dict['auto'], blockinfo=outcome_dict['blockinfo'])
        elif status == DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY:
            retry_after: Optional[datetime.datetime]
            if 'retry_after_utc_timestamp' in outcome_dict:
                retry_after = utc_timestamp_to_datetime(outcome_dict['retry_after_utc_timestamp'])
            else:
                retry_after = None
            return CommandWikiReadOnly(id, command, outcome_dict['reason'], retry_after)
        else:
            raise ValueError('Unknown command status %d' % status)

    def _status_to_command_record_type(self, status: int) -> Type[CommandRecord]:
        if status == DatabaseStore._COMMAND_STATUS_PLAN:
            return CommandPlan
        elif status == DatabaseStore._COMMAND_STATUS_EDIT:
            return CommandEdit
        elif status == DatabaseStore._COMMAND_STATUS_NOOP:
            return CommandNoop
        elif status == DatabaseStore._COMMAND_STATUS_PENDING:
            return CommandPending
        elif status == DatabaseStore._COMMAND_STATUS_PAGE_MISSING:
            return CommandPageMissing
        elif status == DatabaseStore._COMMAND_STATUS_TITLE_INVALID:
            return CommandTitleInvalid
        elif status == DatabaseStore._COMMAND_STATUS_PAGE_PROTECTED:
            return CommandPageProtected
        elif status == DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT:
            return CommandEditConflict
        elif status == DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED:
            return CommandMaxlagExceeded
        elif status == DatabaseStore._COMMAND_STATUS_BLOCKED:
            return CommandBlocked
        elif status == DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY:
            return CommandWikiReadOnly
        else:
            raise ValueError('Unknown command status %d' % status)

    def _tinyint_to_bool(self, val: Optional[int]) -> Optional[bool]:
        if val is None:
            return None
        else:
            return bool(val)


-@dataclass
+@dataclass(frozen=True)
class _BatchCommandRecordsDatabase(BatchCommandRecords):
    batch_id: int
    store: DatabaseStore

    def get_slice(self, offset: int, limit: int) -> List[CommandRecord]:
        command_records = []
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''SELECT `command_id`, `command_page_title`, `command_page_resolve_redirects`, `actions_tpsv`, `command_status`, `command_outcome`
                              FROM `command`
                              JOIN `actions` ON `command_actions` = `actions_id`
                              WHERE `command_batch` = %s
                              ORDER BY `command_id` ASC
                              LIMIT %s OFFSET %s''', (self.batch_id, limit, offset))
            for id, title, resolve_redirects, actions_tpsv, status, outcome in cursor.fetchall():
                command_records.append(self.store._row_to_command_record(id, title, resolve_redirects, actions_tpsv, status, outcome))
        return command_records

    def get_summary(self) -> Dict[Type[CommandRecord], int]:
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''SELECT `command_status`, COUNT(*) AS `count`
                              FROM `command`
                              WHERE `command_batch` = %s
                              GROUP BY `command_status`''', (self.batch_id,))
            return {self.store._status_to_command_record_type(status): count for status, count in cursor.fetchall()}

    def stream_pages(self) -> Iterator[Page]:
        with self.store.connect_streaming() as connection, cast(pymysql.cursors.SSCursor, connection.cursor()) as cursor:
            cursor.execute('''SELECT `command_page_title`, `command_page_resolve_redirects`
                              FROM `command`
                              WHERE `command_batch` = %s
                              ORDER BY `command_id` ASC''', (self.batch_id,))
            for title, resolve_redirects in cursor.fetchall_unbuffered():
                yield self.store._row_to_page(title, resolve_redirects)

    def stream_commands(self) -> Iterator[Command]:
        with self.store.connect_streaming() as connection, cast(pymysql.cursors.SSCursor, connection.cursor()) as cursor:
            cursor.execute('''SELECT `command_page_title`, `command_page_resolve_redirects`, `actions_tpsv`
                              FROM `command`
                              JOIN `actions` ON `command_actions` = `actions_id`
                              WHERE `command_batch` = %s
                              ORDER BY `command_id` ASC''', (self.batch_id,))
            for title, resolve_redirects, actions_tpsv in cursor.fetchall_unbuffered():
                yield self.store._row_to_command(title, resolve_redirects, actions_tpsv)

    def __len__(self) -> int:
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('SELECT COUNT(*) FROM `command` WHERE `command_batch` = %s', (self.batch_id,))
            result = cursor.fetchone()
            assert result, "COUNT(*) must return a result"
            (count,) = result
            return count

    def make_plans_pending(self, offset: int, limit: int) -> List[CommandPending]:
        with self.store.connect() as connection:
            command_ids: List[int] = []
            with connection.cursor() as cursor:
                # the extra subquery layer below is necessary to work around a MySQL/MariaDB restriction;
                # based on https://stackoverflow.com/a/24777566/1420237
                cursor.execute('''SELECT `command_id`
                                  FROM `command`
                                  WHERE `command_id` IN ( SELECT * FROM ( SELECT `command_id`
                                                                          FROM `command`
                                                                          WHERE `command_batch` = %s
                                                                          ORDER BY `command_id` ASC
                                                                          LIMIT %s OFFSET %s ) AS temporary_table)
                                  AND `command_status` = %s
                                  ORDER BY `command_id` ASC
                                  FOR UPDATE''', (self.batch_id, limit, offset, DatabaseStore._COMMAND_STATUS_PLAN))
                for (command_id,) in cursor.fetchall():
                    command_ids.append(command_id)
            if not command_ids:
                connection.commit() # finish the FOR UPDATE
                return []
            with connection.cursor() as cursor:
                cursor.executemany('''UPDATE `command`
                                      SET `command_status` = %s
                                      WHERE `command_id` = %s AND `command_batch` = %s''',
                                   zip(itertools.repeat(DatabaseStore._COMMAND_STATUS_PENDING), command_ids, itertools.repeat(self.batch_id)))
            connection.commit()
            command_records = []
            with connection.cursor() as cursor:
                cursor.execute('''SELECT `command_id`, `command_page_title`, `command_page_resolve_redirects`, `actions_tpsv`, `command_status`, `command_outcome`
                                  FROM `command`
                                  JOIN `actions` ON `command_actions` = `actions_id`
                                  WHERE `command_id` IN (%s)''' % ', '.join(['%s'] * len(command_ids)), command_ids)
                for id, title, resolve_redirects, actions_tpsv, status, outcome in cursor.fetchall():
                    assert status == DatabaseStore._COMMAND_STATUS_PENDING
                    assert outcome is None
                    command_record = self.store._row_to_command_record(id, title, resolve_redirects, actions_tpsv, status, outcome)
                    assert isinstance(command_record, CommandPending)
                    command_records.append(command_record)
            return command_records

    def make_pendings_planned(self, command_record_ids: List[int]) -> None:
        if not command_record_ids:
            return
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''UPDATE `command`
                              SET `command_status` = %%s
                              WHERE `command_id` IN (%s)
                              AND `command_status` = %%s''' % ', '.join(['%s'] * len(command_record_ids)),
                           [DatabaseStore._COMMAND_STATUS_PLAN, *command_record_ids, DatabaseStore._COMMAND_STATUS_PENDING])
            connection.commit()

    def store_finish(self, command_finish: CommandFinish) -> None:
        last_updated = now()
        last_updated_utc_timestamp = datetime_to_utc_timestamp(last_updated)
        status, outcome = self.store._command_finish_to_row(command_finish)
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''UPDATE `command`
                              SET `command_status` = %s, `command_outcome` = %s
                              WHERE `command_id` = %s AND `command_batch` = %s''', (status, json.dumps(outcome), command_finish.id, self.batch_id))
            cursor.execute('''UPDATE `batch`
                              SET `batch_last_updated_utc_timestamp` = %s
                              WHERE `batch_id` = %s''', (last_updated_utc_timestamp, self.batch_id))
            connection.commit()
            if isinstance(command_finish, CommandFailure) and \
               command_finish.can_retry_later():
                # append a fresh plan for the same command
                cursor.execute('''INSERT INTO `command` (`command_batch`, `command_page_title`, `command_page_resolve_redirects`, `command_actions`, `command_status`, `command_outcome`)
                                  SELECT `command_batch`, `command_page_title`, `command_page_resolve_redirects`, `command_actions`, %s, NULL
                                  FROM `command`
                                  WHERE `command_id` = %s''', (DatabaseStore._COMMAND_STATUS_PLAN, command_finish.id))
                command_plan_id = cursor.lastrowid
                cursor.execute('''INSERT INTO `retry` (`retry_failure`, `retry_new`) VALUES (%s, %s)''', (command_finish.id, command_plan_id))
                connection.commit()
            else:
                # close the batch if no planned or pending commands are left in it
                cursor.execute('''SELECT 1
                                  FROM `command`
                                  WHERE `command_batch` = %s
                                  AND `command_status` IN (%s, %s)
                                  LIMIT 1''', (self.batch_id, DatabaseStore._COMMAND_STATUS_PLAN, DatabaseStore._COMMAND_STATUS_PENDING))
                if not cursor.fetchone():
                    cursor.execute('''UPDATE `batch`
                                      SET `batch_status` = %s
                                      WHERE `batch_id` = %s''', (DatabaseStore._BATCH_STATUS_CLOSED, self.batch_id))
                    connection.commit()
                    self.store._stop_background_by_id(self.batch_id)

    def __eq__(self, value: Any) -> bool:
        # limited test to avoid overly expensive full comparison
        return type(value) is _BatchCommandRecordsDatabase and \
            self.batch_id == value.batch_id


-@dataclass
+@dataclass(frozen=True)
class _BatchBackgroundRunsDatabase(BatchBackgroundRuns):
    batch_id: int
    domain: str
    store: DatabaseStore

    def currently_running(self) -> bool:
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''SELECT 1
                              FROM `background`
                              WHERE `background_batch` = %s
                              AND `background_stopped_utc_timestamp` IS NULL
                              LIMIT 1''', (self.batch_id,))
            return cursor.fetchone() is not None

    def _row_to_background_run(self, started_utc_timestamp: int, started_user_name: str, started_local_user_id: int, started_global_user_id: int,
                               stopped_utc_timestamp: int, stopped_user_name: str, stopped_local_user_id: int, stopped_global_user_id: int) \
            -> Tuple[Tuple[datetime.datetime, LocalUser], Optional[Tuple[datetime.datetime, Optional[LocalUser]]]]: # NOQA: E127 (indentation)
        background_start = (utc_timestamp_to_datetime(started_utc_timestamp), LocalUser(started_user_name, self.domain, started_local_user_id, started_global_user_id))
        background_stop: Optional[Tuple[datetime.datetime, Optional[LocalUser]]]
        if stopped_utc_timestamp:
            stopped_local_user: Optional[LocalUser]
            if stopped_user_name:
                stopped_local_user = LocalUser(stopped_user_name, self.domain, stopped_local_user_id, stopped_global_user_id)
            else:
                stopped_local_user = None
            background_stop = (utc_timestamp_to_datetime(stopped_utc_timestamp), stopped_local_user)
        else:
            background_stop = None
        return (background_start, background_stop)

    def get_last(self) -> Optional[Tuple[Tuple[datetime.datetime, LocalUser], Optional[Tuple[datetime.datetime, Optional[LocalUser]]]]]:
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''SELECT `background_started_utc_timestamp`, `started`.`localuser_user_name`, `started`.`localuser_local_user_id`, `started`.`localuser_global_user_id`,
                                     `background_stopped_utc_timestamp`, `stopped`.`localuser_user_name`, `stopped`.`localuser_local_user_id`, `stopped`.`localuser_global_user_id`
                              FROM `background`
                              JOIN `localuser` AS `started` ON `background_started_localuser` = `started`.`localuser_id`
                              LEFT JOIN `localuser` AS `stopped` ON `background_stopped_localuser` = `stopped`.`localuser_id`
                              WHERE `background_batch` = %s
                              ORDER BY `background_id` DESC
                              LIMIT 1''', (self.batch_id,))
            result = cursor.fetchone()
            if result:
                return self._row_to_background_run(*result)
            else:
                return None

    def get_all(self) -> Sequence[Tuple[Tuple[datetime.datetime, LocalUser], Optional[Tuple[datetime.datetime, Optional[LocalUser]]]]]:
        with self.store.connect() as connection, connection.cursor() as cursor:
            cursor.execute('''SELECT `background_started_utc_timestamp`, `started`.`localuser_user_name`, `started`.`localuser_local_user_id`, `started`.`localuser_global_user_id`,
                                     `background_stopped_utc_timestamp`, `stopped`.`localuser_user_name`, `stopped`.`localuser_local_user_id`, `stopped`.`localuser_global_user_id`
                              FROM `background`
                              JOIN `localuser` AS `started` ON `background_started_localuser` = `started`.`localuser_id`
                              LEFT JOIN `localuser` AS `stopped` ON `background_stopped_localuser` = `stopped`.`localuser_id`
                              WHERE `background_batch` = %s
                              ORDER BY `background_id` ASC''', (self.batch_id,))
            return [self._row_to_background_run(*row) for row in cursor.fetchall()]

    def __eq__(self, value: Any) -> bool:
        # limited test to avoid overly expensive full comparison
        return type(value) is _BatchBackgroundRunsDatabase and \
            self.batch_id == value.batch_id


-@dataclass
+@dataclass(frozen=True)
class _LocalUserStore:
    """Encapsulates access to a local user account in the localuser table.

    When a user has been renamed, the same ID will still be used
    and the name will be updated once the user is stored the next time."""

    domain_store: StringTableStore

    def acquire_localuser_id(self, connection: pymysql.connections.Connection, local_user: LocalUser) -> int:
        domain_id = self.domain_store.acquire_id(connection, local_user.domain)
        with connection.cursor() as cursor:
            cursor.execute('''INSERT INTO `localuser` (`localuser_user_name`, `localuser_domain`, `localuser_local_user_id`, `localuser_global_user_id`)
                              VALUES (%s, %s, %s, %s)
                              ON DUPLICATE KEY UPDATE `localuser_user_name` = %s''',
                           (local_user.user_name, domain_id, local_user.local_user_id, local_user.global_user_id, local_user.user_name))
            localuser_id = cursor.lastrowid
            if not localuser_id:
                # not returned in the ON DUPLICATE KEY UPDATE case, apparently
                cursor.execute('''SELECT `localuser_id`
                                  FROM `localuser`
                                  WHERE `localuser_local_user_id` = %s
                                  AND `localuser_domain` = %s''', (local_user.local_user_id, domain_id))
                result = cursor.fetchone()
                assert result, "COUNT(*) must return a result"
                (localuser_id,) = result
                assert cursor.fetchone() is None
        connection.commit()
        return localuser_id
diff --git a/in_memory.py b/in_memory.py
index f9adf48..c80a8ed 100644
--- a/in_memory.py
+++ b/in_memory.py
@@ -1,203 +1,203 @@
from dataclasses import dataclass
import datetime
import mwapi # type: ignore
import mwoauth # type: ignore
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Type, cast

from batch import NewBatch, StoredBatch, OpenBatch, ClosedBatch
from batch_background_runs import BatchBackgroundRuns
from batch_command_records import BatchCommandRecords
from command import Command, CommandPlan, CommandPending, CommandRecord, CommandFinish, CommandFailure
from localuser import LocalUser
from page import Page
from store import BatchStore, _local_user_from_session
from timestamp import now


class InMemoryStore(BatchStore):

    def __init__(self) -> None:
        self.next_batch_id = 1
        self.next_command_id = 1
        self.batches: Dict[int, StoredBatch] = {}
        self.background_sessions: Dict[int, mwapi.Session] = {}
        self.background_suspensions: Dict[int, datetime.datetime] = {}

    def store_batch(self, new_batch: NewBatch, session: mwapi.Session) -> OpenBatch:
        created = now()
        local_user = _local_user_from_session(session)
        command_plans: List[CommandRecord] = []
        for command in new_batch.commands:
            command_plans.append(CommandPlan(self.next_command_id, command))
            self.next_command_id += 1
        open_batch = OpenBatch(self.next_batch_id, local_user, local_user.domain, new_batch.title, created, created, _BatchCommandRecordsList(command_plans, self.next_batch_id, self), _BatchBackgroundRunsList([], self))
        self.next_batch_id += 1
        self.batches[open_batch.id] = open_batch
        return open_batch

    def get_batch(self, id: int) -> Optional[StoredBatch]:
        stored_batch = self.batches.get(id)
        if stored_batch is None:
            return None
        command_records = cast(_BatchCommandRecordsList, stored_batch.command_records).command_records
        if isinstance(stored_batch, OpenBatch) and \
           all(map(lambda command_record: isinstance(command_record, CommandFinish), command_records)):
            stored_batch = ClosedBatch(stored_batch.id, stored_batch.local_user, stored_batch.domain, stored_batch.title, stored_batch.created, stored_batch.last_updated, stored_batch.command_records, stored_batch.background_runs)
            self.batches[id] = stored_batch
        return stored_batch

    def get_batches_slice(self, offset: int, limit: int) -> Sequence[StoredBatch]:
        return [cast(StoredBatch, self.get_batch(id)) for id in
                sorted(self.batches.keys(), reverse=True)[offset:offset+limit]]

    def get_batches_count(self) -> int:
        return len(self.batches)

    def start_background(self, batch: OpenBatch, session: mwapi.Session) -> None:
        started = now()
        local_user = _local_user_from_session(session)
        background_runs = cast(_BatchBackgroundRunsList, batch.background_runs)
        if not background_runs.currently_running():
            background_runs.background_runs.append(((started, local_user), None))
            self.background_sessions[batch.id] = session

    def stop_background(self, batch: StoredBatch, session: Optional[mwapi.Session] = None) -> None:
        stopped = now()
        local_user: Optional[LocalUser]
        if session:
            local_user = _local_user_from_session(session)
        else:
            local_user = None
        background_runs = cast(_BatchBackgroundRunsList, batch.background_runs)
        if background_runs.currently_running():
            background_runs.background_runs[-1] = (background_runs.background_runs[-1][0], (stopped, local_user))
            del self.background_sessions[batch.id]
        self.background_suspensions.pop(batch.id, None)

    def suspend_background(self, batch: StoredBatch, until: datetime.datetime) -> None:
        self.background_suspensions[batch.id] = until

    def make_plan_pending_background(self, consumer_token: mwoauth.ConsumerToken, user_agent: str) -> Optional[Tuple[OpenBatch, CommandPending, mwapi.Session]]:
        batches_by_last_updated = sorted(self.batches.values(), key=lambda batch: batch.last_updated)
        for batch in batches_by_last_updated:
            if not batch.background_runs.currently_running():
                continue
            if batch.id in self.background_suspensions:
                if self.background_suspensions[batch.id] < now():
                    del self.background_suspensions[batch.id]
                else:
                    continue
            assert isinstance(batch, OpenBatch)
            assert isinstance(batch.command_records, _BatchCommandRecordsList)
            for index, command_plan in enumerate(batch.command_records.command_records):
                if not isinstance(command_plan, CommandPlan):
                    continue
                command_pending = CommandPending(command_plan.id, command_plan.command)
                batch.command_records.command_records[index] = command_pending
                return batch, command_pending, self.background_sessions[batch.id]
        return None


-@dataclass
+@dataclass(frozen=True)
class _BatchCommandRecordsList(BatchCommandRecords):
    command_records: List[CommandRecord]
    batch_id: int
    store: InMemoryStore

    def get_slice(self, offset: int, limit: int) -> List[CommandRecord]:
        return self.command_records[offset:offset+limit]

    def get_summary(self) -> Dict[Type[CommandRecord], int]:
        ret: Dict[Type[CommandRecord], int] = {}
        for command_record in self.command_records:
            t = type(command_record)
            ret[t] = ret.get(t, 0) + 1
        return ret

    def stream_pages(self) -> Iterator[Page]:
        for command_record in self.command_records:
            yield command_record.command.page

    def stream_commands(self) -> Iterator[Command]:
        for command_record in self.command_records:
            yield command_record.command

    def make_plans_pending(self, offset: int, limit: int) -> List[CommandPending]:
        command_pendings = []
        for index, command_plan in enumerate(self.command_records[offset:offset+limit]):
            if not isinstance(command_plan, CommandPlan):
                continue
            command_pending = CommandPending(command_plan.id, command_plan.command)
            self.command_records[index] = command_pending
            command_pendings.append(command_pending)
        return command_pendings

    def make_pendings_planned(self, command_record_ids: List[int]) -> None:
        for index, command_pending in enumerate(self.command_records):
            if not isinstance(command_pending, CommandPending):
                continue
            if command_pending.id not in command_record_ids:
                continue
            command_plan = CommandPlan(command_pending.id, command_pending.command)
            self.command_records[index] = command_plan

    def store_finish(self, command_finish: CommandFinish) -> None:
        for index, command_record in enumerate(self.command_records):
            if command_record.id == command_finish.id:
                self.command_records[index] = command_finish
                break
        else:
            raise KeyError('command not found')
        self.store.batches[self.batch_id].last_updated = now()
        if isinstance(command_finish, CommandFailure) and \
           command_finish.can_retry_later():
            # append a fresh plan for the same command
            command_plan = CommandPlan(self.store.next_command_id, command_finish.command)
            self.store.next_command_id += 1
            self.command_records.append(command_plan)

    def __len__(self) -> int:
        return len(self.command_records)

    def __eq__(self, value: Any) -> bool:
        return type(value) is _BatchCommandRecordsList and \
            self.command_records == value.command_records


-@dataclass
+@dataclass(frozen=True)
class _BatchBackgroundRunsList(BatchBackgroundRuns):
    background_runs: List[Tuple[Tuple[datetime.datetime, LocalUser], Optional[Tuple[datetime.datetime, Optional[LocalUser]]]]]
    store: InMemoryStore

    def get_last(self) -> Optional[Tuple[Tuple[datetime.datetime, LocalUser], Optional[Tuple[datetime.datetime, Optional[LocalUser]]]]]:
        if self.background_runs:
            return self.background_runs[-1]
        else:
            return None

    def get_all(self) -> Sequence[Tuple[Tuple[datetime.datetime, LocalUser], Optional[Tuple[datetime.datetime, Optional[LocalUser]]]]]:
        return self.background_runs

    def __eq__(self, value: Any) -> bool:
        return type(value) is _BatchBackgroundRunsList and \
            self.background_runs == value.background_runs
diff --git a/localuser.py b/localuser.py
index 8157968..fa4b16d 100644
--- a/localuser.py
+++ b/localuser.py
@@ -1,14 +1,14 @@
from dataclasses import dataclass, field


-@dataclass
+@dataclass(frozen=True)
class LocalUser:
    """A user account local to one wiki."""

    user_name: str = field(compare=False) # ignored in __eq__, to account for renamed users
    domain: str
    local_user_id: int
    global_user_id: int

    def __str__(self) -> str:
        return self.user_name + '@' + self.domain
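
Illustrative addendum (not part of the patch): the practical effect of `frozen=True` on these dataclasses is that instances can no longer be mutated after construction (attribute assignment raises `dataclasses.FrozenInstanceError`), and because `eq` stays at its default of True they also get a generated `__hash__`, so records such as `LocalUser` become usable as set members and dict keys. A minimal standalone sketch of that behaviour, reusing the `LocalUser` definition from this patch with made-up example values:

    from dataclasses import FrozenInstanceError, dataclass, field

    @dataclass(frozen=True)
    class LocalUser:
        """A user account local to one wiki."""
        user_name: str = field(compare=False)  # ignored in __eq__, to account for renamed users
        domain: str
        local_user_id: int
        global_user_id: int

    # user_name is excluded from __eq__ and from the generated __hash__
    # (compare=False), so a renamed user still matches the stored record
    before = LocalUser('Old name', 'www.wikidata.org', 123, 456)
    after = LocalUser('New name', 'www.wikidata.org', 123, 456)
    assert before == after
    assert hash(before) == hash(after)
    assert len({before, after}) == 1  # usable in sets and as dict keys

    # frozen=True rejects mutation after construction
    try:
        before.user_name = 'Other name'
    except FrozenInstanceError:
        pass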