class Runner():
    """Runs planned commands against a wiki through an mwapi session.

    Page contents are loaded by prepare_pages() and kept in
    ``prepared_pages`` (keyed by page title); run_command() applies one
    command plan to the prepared page and reports the outcome as a
    CommandFinish record.
    """

    def __init__(self, session: mwapi.Session, summary_suffix: Optional[str] = None):
        """Remember the session and fetch a CSRF token for later edits.

        session: the mwapi session used for all API requests.
        summary_suffix: optional text appended to every edit summary.
        """
        self.session = session
        self.csrf_token = session.get(action='query',
                                      meta='tokens')['query']['tokens']['csrftoken']
        self.summary_suffix = summary_suffix
        self.prepared_pages = {} # type: Dict[str, dict]

    def prepare_pages(self, titles: List[str]):
        """Load the current main-slot content of the given pages into the cache.

        titles: between 1 and 50 page titles (asserted; presumably the
        API's per-request limit for titles – TODO confirm).

        Missing pages are recorded with a 'missing' marker and the
        server's current timestamp. Titles that the API normalized are
        additionally registered under the title as it was requested.

        Raises ValueError if a page's main slot is not wikitext in the
        text/x-wiki format.
        """
        assert titles
        assert len(titles) <= 50

        response = self.session.get(action='query',
                                    titles=titles,
                                    prop=['revisions'],
                                    rvprop=['ids', 'content', 'contentmodel', 'timestamp'],
                                    rvslots=['main'],
                                    curtimestamp=True,
                                    formatversion=2)

        for page in response['query']['pages']:
            title = page['title']
            if 'missing' in page:
                self.prepared_pages[title] = {
                    'missing': True,
                    'curtimestamp': response['curtimestamp'],
                }
                continue
            revision = page['revisions'][0]
            slot = revision['slots']['main']
            if slot['contentmodel'] != 'wikitext' or slot['contentformat'] != 'text/x-wiki':
                raise ValueError('Unexpected content model or format for revision %d of page %s, refusing to edit!' % (revision['revid'], title))
            self.prepared_pages[title] = {
                'wikitext': slot['content'],
                'page_id': page['pageid'],
                'base_timestamp': revision['timestamp'],
                'base_revid': revision['revid'],
                'start_timestamp': response['curtimestamp'],
            }

        # make the prepared page reachable under the non-normalized title as well
        for normalization in response['query'].get('normalized', []):
            self.prepared_pages[normalization['from']] = self.prepared_pages[normalization['to']]

    def run_command(self, plan: CommandPlan) -> CommandFinish:
        """Apply one command plan to its page and return the outcome.

        The page is prepared on demand if it is not in the cache yet.
        Returns CommandPageMissing if the page does not exist,
        CommandNoop if applying the actions leaves the wikitext
        unchanged, CommandEdit after a successful edit, and
        CommandEditConflict, CommandMaxlagExceeded, CommandBlocked or
        CommandWikiReadOnly for the corresponding API errors. Any other
        API error is re-raised.
        """
        title = plan.command.page
        if title not in self.prepared_pages:
            self.prepare_pages([title])
        prepared_page = self.prepared_pages[title]
        category_info = siteinfo.category_info(self.session)

        if 'missing' in prepared_page:
            return CommandPageMissing(plan.id, plan.command, curtimestamp=prepared_page['curtimestamp'])

        wikitext, actions = plan.command.apply(prepared_page['wikitext'], category_info)

        # one summary fragment per action; no-op actions are shown in parentheses
        summary = ''
        for action, noop in actions:
            action_summary = action.summary(category_info)
            if noop:
                action_summary = siteinfo.parentheses(self.session, action_summary)
            if summary:
                summary += siteinfo.comma_separator(self.session)
            summary += action_summary

        if self.summary_suffix:
            summary += siteinfo.semicolon_separator(self.session)
            summary += self.summary_suffix

        if wikitext == prepared_page['wikitext']:
            return CommandNoop(plan.id, plan.command, prepared_page['base_revid'])
        try:
            response = self.session.post(**{'action': 'edit',
                                            'pageid': prepared_page['page_id'],
                                            'text': wikitext,
                                            'summary': summary,
                                            'bot': True,
                                            'basetimestamp': prepared_page['base_timestamp'],
                                            'starttimestamp': prepared_page['start_timestamp'],
                                            'contentformat': 'text/x-wiki',
                                            'contentmodel': 'wikitext',
                                            'token': self.csrf_token,
                                            'assert': 'user', # assert is a keyword, can’t use kwargs syntax :(
                                            'maxlag': 5,
                                            'formatversion': 2})
        except mwapi.errors.APIError as e:
            if e.code == 'editconflict':
                del self.prepared_pages[title] # this must be outdated now
                return CommandEditConflict(plan.id, plan.command)
            elif e.code == 'maxlag':
                retry_after_seconds = 5 # the API returns this in a Retry-After header, but mwapi hides that from us :(
                retry_after = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=retry_after_seconds)
                # whole seconds only, the store keeps timestamps as integers
                retry_after = retry_after.replace(microsecond=0)
                return CommandMaxlagExceeded(plan.id, plan.command, retry_after)
            elif e.code == 'blocked' or e.code == 'autoblocked':
                auto = e.code == 'autoblocked'
                blockinfo = None # the API returns this in a 'blockinfo' member of the 'error' object, but mwapi hides that from us :(
                return CommandBlocked(plan.id, plan.command, auto, blockinfo)
            elif e.code == 'readonly':
                reason = None # the API returns this in a 'readonlyreason' member of the 'error' object, but mwapi hides that from us :(
                return CommandWikiReadOnly(plan.id, plan.command, reason)
            else:
                raise

        # NOTE(review): the cached entry for this title still holds the
        # pre-edit wikitext and timestamps after a successful edit;
        # confirm that no batch edits the same page twice.
        assert response['edit']['oldrevid'] == prepared_page['base_revid']
        return CommandEdit(plan.id, plan.command, response['edit']['oldrevid'], response['edit']['newrevid'])
class DatabaseStore(BatchStore):
    """Batch store backed by a MySQL/MariaDB database accessed via PyMySQL.

    Batches live in the `batch` table and their commands in the
    `command` table; domain names and action lists are kept in separate
    string tables and referenced by ID. Timestamps are stored as whole
    seconds since the epoch (UTC).
    """

    _BATCH_STATUS_OPEN = 0

    _COMMAND_STATUS_PLAN = 0
    _COMMAND_STATUS_EDIT = 1
    _COMMAND_STATUS_NOOP = 2
    _COMMAND_STATUS_PAGE_MISSING = 129
    _COMMAND_STATUS_EDIT_CONFLICT = 130
    _COMMAND_STATUS_MAXLAG_EXCEEDED = 131
    _COMMAND_STATUS_BLOCKED = 132
    _COMMAND_STATUS_WIKI_READ_ONLY = 133

    def __init__(self, connection_params: dict):
        """Set up the store.

        connection_params: keyword arguments for pymysql.connect();
        'charset' defaults to 'utf8mb4' when not given. The dict is
        copied, the caller's object is left untouched.
        """
        connection_params = dict(connection_params)
        connection_params.setdefault('charset', 'utf8mb4')
        self.connection_params = connection_params
        self.domain_store = _StringTableStore('domain', 'domain_id', 'domain_hash', 'domain_name')
        self.actions_store = _StringTableStore('actions', 'actions_id', 'actions_hash', 'actions_tpsv')

    @contextlib.contextmanager
    def _connect(self) -> Generator[pymysql.connections.Connection, None, None]:
        """Open a fresh database connection and close it when the block ends."""
        connection = pymysql.connect(**self.connection_params)
        try:
            yield connection
        finally:
            connection.close()

    def _datetime_to_utc_timestamp(self, dt: datetime.datetime) -> int:
        """Convert a UTC datetime without microseconds to integer epoch seconds."""
        assert dt.tzinfo == datetime.timezone.utc
        assert dt.microsecond == 0
        return int(dt.timestamp())

    def _utc_timestamp_to_datetime(self, timestamp: int) -> datetime.datetime:
        """Convert integer epoch seconds to a timezone-aware UTC datetime."""
        return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)

    def store_batch(self, new_batch: NewBatch, session: mwapi.Session) -> OpenBatch:
        """Persist a new batch with all its commands and return it as an open batch.

        The user and domain metadata are taken from the session; the
        created and last-updated timestamps are both set to the current
        time.
        """
        user_name, local_user_id, global_user_id, domain = _metadata_from_session(session)
        created = _now()
        created_utc_timestamp = self._datetime_to_utc_timestamp(created)

        with self._connect() as connection:
            # acquire_id() may commit on this connection, so resolve every
            # string ID before the first INSERT; that way the batch row and
            # its command rows are written in one transaction.
            domain_id = self.domain_store.acquire_id(connection, domain)
            pages_and_actions_ids = [(command.page, self.actions_store.acquire_id(connection, command.actions_tpsv()))
                                     for command in new_batch.commands]

            with connection.cursor() as cursor:
                cursor.execute('INSERT INTO `batch` (`batch_user_name`, `batch_local_user_id`, `batch_global_user_id`, `batch_domain_id`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`) VALUES (%s, %s, %s, %s, %s, %s, %s)',
                               (user_name, local_user_id, global_user_id, domain_id, created_utc_timestamp, created_utc_timestamp, DatabaseStore._BATCH_STATUS_OPEN))
                batch_id = cursor.lastrowid

            with connection.cursor() as cursor:
                cursor.executemany('INSERT INTO `command` (`command_batch`, `command_page`, `command_actions_id`, `command_status`, `command_outcome`) VALUES (%s, %s, %s, %s, NULL)',
                                   [(batch_id, page, actions_id, DatabaseStore._COMMAND_STATUS_PLAN)
                                    for page, actions_id in pages_and_actions_ids])

            connection.commit()

        return OpenBatch(batch_id,
                         user_name,
                         local_user_id,
                         global_user_id,
                         domain,
                         created,
                         created,
                         _DatabaseCommandRecords(batch_id, self))

    def get_batch(self, id: int) -> Optional[OpenBatch]:
        """Load the batch with the given ID, or return None if there is no such batch."""
        with self._connect() as connection:
            with connection.cursor() as cursor:
                cursor.execute('''SELECT `batch_id`, `batch_user_name`, `batch_local_user_id`, `batch_global_user_id`, `domain_name`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`
                                  FROM `batch`
                                  JOIN `domain` ON `batch_domain_id` = `domain_id`
                                  WHERE `batch_id` = %s''', (id,))
                result = cursor.fetchone()
        if not result:
            return None
        return self._result_to_batch(result)

    def _result_to_batch(self, result: tuple) -> OpenBatch:
        """Turn one row of the batch/domain join into an OpenBatch.

        The row must have the column order used by get_batch() and
        get_latest_batches(); only open batches are supported (asserted).
        """
        batch_id, user_name, local_user_id, global_user_id, domain, created_utc_timestamp, last_updated_utc_timestamp, status = result
        assert status == DatabaseStore._BATCH_STATUS_OPEN
        created = self._utc_timestamp_to_datetime(created_utc_timestamp)
        last_updated = self._utc_timestamp_to_datetime(last_updated_utc_timestamp)
        return OpenBatch(batch_id,
                         user_name,
                         local_user_id,
                         global_user_id,
                         domain,
                         created,
                         last_updated,
                         _DatabaseCommandRecords(batch_id, self))

    def get_latest_batches(self) -> Sequence[OpenBatch]:
        """Return up to ten batches with the highest IDs, newest first."""
        with self._connect() as connection:
            with connection.cursor() as cursor:
                cursor.execute('''SELECT `batch_id`, `batch_user_name`, `batch_local_user_id`, `batch_global_user_id`, `domain_name`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`
                                  FROM `batch`
                                  JOIN `domain` ON `batch_domain_id` = `domain_id`
                                  ORDER BY `batch_id` DESC
                                  LIMIT 10''')
                return [self._result_to_batch(result) for result in cursor.fetchall()]
def _row_to_command_record(self, id: int, page: str, actions_tpsv: str, status: int, outcome: Optional[str]) -> CommandRecord:
    """Build the command record described by one row of the command/actions join.

    id: the command ID.
    page: the page title of the command.
    actions_tpsv: the actions of the command, separated by '|'.
    status: one of the DatabaseStore._COMMAND_STATUS_* constants.
    outcome: JSON-encoded details of a finished command, or None.

    Raises ValueError for an unknown status, or when a status that
    carries details has no stored outcome.
    """
    # parsed up front so that malformed JSON is reported for any status
    outcome_dict = json.loads(outcome) if outcome else None # type: Optional[dict]

    command = Command(page, [parse_tpsv.parse_action(field) for field in actions_tpsv.split('|')])

    statuses_with_outcome = (
        DatabaseStore._COMMAND_STATUS_EDIT,
        DatabaseStore._COMMAND_STATUS_NOOP,
        DatabaseStore._COMMAND_STATUS_PAGE_MISSING,
        DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED,
        DatabaseStore._COMMAND_STATUS_BLOCKED,
        DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY,
    )
    if status in statuses_with_outcome and outcome_dict is None:
        # without this guard the lookups below would fail on an unbound or None value
        raise ValueError('Command %d has status %d but no outcome' % (id, status))

    if status == DatabaseStore._COMMAND_STATUS_PLAN:
        assert outcome is None
        return CommandPlan(id, command)
    elif status == DatabaseStore._COMMAND_STATUS_EDIT:
        return CommandEdit(id,
                           command,
                           base_revision=outcome_dict['base_revision'],
                           revision=outcome_dict['revision'])
    elif status == DatabaseStore._COMMAND_STATUS_NOOP:
        return CommandNoop(id,
                           command,
                           revision=outcome_dict['revision'])
    elif status == DatabaseStore._COMMAND_STATUS_PAGE_MISSING:
        return CommandPageMissing(id,
                                  command,
                                  curtimestamp=outcome_dict['curtimestamp'])
    elif status == DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT:
        return CommandEditConflict(id,
                                   command)
    elif status == DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED:
        return CommandMaxlagExceeded(id,
                                     command,
                                     self.store._utc_timestamp_to_datetime(outcome_dict['retry_after_utc_timestamp']))
    elif status == DatabaseStore._COMMAND_STATUS_BLOCKED:
        return CommandBlocked(id,
                              command,
                              auto=outcome_dict['auto'],
                              blockinfo=outcome_dict['blockinfo'])
    elif status == DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY:
        return CommandWikiReadOnly(id,
                                   command,
                                   outcome_dict['reason'])
    else:
        raise ValueError('Unknown command status %d' % status)
class _StringTableStore:
    """Encapsulates access to a string that has been extracted into a separate table.

    The separate table is expected to have three columns:
    an automatically incrementing ID,
    an unsigned integer hash (the first four bytes of the SHA2-256 hash of the string),
    and the string itself.

    IDs for the most recently used strings are cached,
    but to look up the string for an ID,
    callers should use a plain SQL JOIN for now."""

    def __init__(self,
                 table_name: str,
                 id_column_name: str,
                 hash_column_name: str,
                 string_column_name: str):
        """Remember the table layout and create an empty ID cache (up to 1024 entries)."""
        self.table_name = table_name
        self.id_column_name = id_column_name
        self.hash_column_name = hash_column_name
        self.string_column_name = string_column_name
        self._cache = cachetools.LRUCache(maxsize=1024) # type: cachetools.LRUCache[str, int]
        self._cache_lock = threading.RLock()

    def _hash(self, string: str) -> int:
        """Return the first four bytes of the SHA-256 hash of the UTF-8 encoded string as an unsigned integer."""
        digest = hashlib.sha256(string.encode('utf8')).digest()
        return int.from_bytes(digest[:4], 'big')

    # the cache key ignores the connection, entries are keyed on the string alone
    @cachetools.cachedmethod(operator.attrgetter('_cache'), key=lambda connection, string: string, lock=operator.attrgetter('_cache_lock'))
    def acquire_id(self, connection: pymysql.connections.Connection, string: str) -> int:
        """Return the ID of the string in the table, inserting the string if necessary.

        The row is looked up by hash and by the string itself, so that
        two different strings whose 32-bit hashes collide never share an
        ID. The connection is committed before returning, both after the
        lookup and after an insert.
        """
        string_hash = self._hash(string)

        with connection.cursor() as cursor:
            cursor.execute('''SELECT `%s`
                              FROM `%s`
                              WHERE `%s` = %%s
                              AND `%s` = %%s
                              FOR UPDATE''' % (self.id_column_name, self.table_name, self.hash_column_name, self.string_column_name),
                           (string_hash, string))
            result = cursor.fetchone()
        if result:
            connection.commit() # finish the FOR UPDATE
            return result[0]

        with connection.cursor() as cursor:
            cursor.execute('''INSERT INTO `%s` (`%s`, `%s`)
                              VALUES (%%s, %%s)''' % (self.table_name, self.string_column_name, self.hash_column_name),
                           (string, string_hash))
            string_id = cursor.lastrowid
        connection.commit()
        return string_id
int unsigned NOT NULL, batch_status int unsigned NOT NULL ) CHARACTER SET = 'utf8mb4' COLLATE = 'utf8mb4_bin'; CREATE TABLE command ( command_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT, command_batch int unsigned NOT NULL, command_page text NOT NULL, command_actions_id int unsigned NOT NULL, command_status int unsigned NOT NULL, command_outcome text ) CHARACTER SET = 'utf8mb4' COLLATE = 'utf8mb4_bin'; CREATE INDEX command_batch ON command (command_batch); CREATE TABLE domain ( domain_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT, domain_hash int unsigned NOT NULL, -- first four bytes of the SHA2-256 hash of the domain_name domain_name varchar(255) binary NOT NULL ) CHARACTER SET = 'utf8mb4' COLLATE = 'utf8mb4_bin'; CREATE INDEX domain_hash ON domain (domain_hash); CREATE TABLE actions ( actions_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT, actions_hash int unsigned NOT NULL, -- first four bytes of the SHA2-256 hash of the actions_tpsv actions_tpsv text NOT NULL ) CHARACTER SET = 'utf8mb4' COLLATE = 'utf8mb4_bin'; CREATE INDEX actions_hash ON actions (actions_hash); diff --git a/test_command.py b/test_command.py index 7b3e54b..197c419 100644 --- a/test_command.py +++ b/test_command.py @@ -1,302 +1,302 @@ import datetime import pytest from action import AddCategoryAction, RemoveCategoryAction from command import Command, CommandPlan, CommandEdit, CommandNoop, CommandPageMissing, CommandEditConflict, CommandMaxlagExceeded, CommandBlocked, CommandWikiReadOnly from test_action import addCategory1, removeCategory2, addCategory3 command1 = Command('Page 1', [addCategory1, removeCategory2]) command2 = Command('Page 2', [addCategory3]) def test_Command_apply(): wikitext = 'Test page for the QuickCategories tool.\n[[Category:Already present cat]]\n[[Category:Removed cat]]\nBottom text' command = Command('Page title', [AddCategoryAction('Added cat'), AddCategoryAction('Already present cat'), RemoveCategoryAction('Removed cat'), RemoveCategoryAction('Not 
present cat')]) new_wikitext, actions = command.apply(wikitext, ('Category', ['Category'], 'first-letter')) assert new_wikitext == 'Test page for the QuickCategories tool.\n[[Category:Already present cat]]\n[[Category:Added cat]]\nBottom text' assert actions == [(command.actions[0], False), (command.actions[1], True), (command.actions[2], False), (command.actions[3], True)] def test_Command_cleanup(): command = Command('Page_from_URL', [AddCategoryAction('Category_from_URL')]) command.cleanup() assert command == Command('Page from URL', [AddCategoryAction('Category from URL')]) def test_Command_actions_tpsv(): assert command1.actions_tpsv() == '+Category:Cat 1|-Category:Cat 2' def test_Command_eq_same(): assert command1 == command1 def test_Command_eq_equal(): assert command1 == Command(command1.page, command1.actions) def test_Command_eq_different_type(): assert command1 != addCategory1 assert command1 != None def test_Command_eq_different_page(): assert command1 != Command('Page A', command1.actions) assert command1 != Command('Page_1', command1.actions) def test_Command_eq_different_actions(): assert command1 != Command(command1.page, [addCategory1]) def test_Command_str(): assert str(command1) == 'Page 1|+Category:Cat 1|-Category:Cat 2' def test_Command_repr(): assert eval(repr(command1)) == command1 commandPlan1 = CommandPlan(42, command1) def test_CommandPlan_eq_same(): assert commandPlan1 == commandPlan1 def test_CommandPlan_eq_equal(): assert commandPlan1 == CommandPlan(commandPlan1.id, commandPlan1.command) def test_CommandPlan_eq_different_id(): assert commandPlan1 != CommandPlan(43, commandPlan1.command) def test_CommandPlan_eq_different_command(): assert commandPlan1 != CommandPlan(commandPlan1.id, command2) def test_CommandPlan_str(): assert str(commandPlan1) == str(command1) def test_CommandPlan_repr(): assert eval(repr(commandPlan1)) == commandPlan1 commandEdit1 = CommandEdit(42, command2, 1234, 1235) def test_CommandEdit_init(): with 
pytest.raises(AssertionError): CommandEdit(42, command1, base_revision=1235, revision=1234) def test_CommandEdit_eq_same(): assert commandEdit1 == commandEdit1 def test_CommandEdit_eq_equal(): assert commandEdit1 == CommandEdit(42, command2, 1234, 1235) def test_CommandEdit_eq_different_id(): assert commandEdit1 != CommandEdit(43, commandEdit1.command, commandEdit1.base_revision, commandEdit1.revision) def test_CommandEdit_eq_different_command(): assert commandEdit1 != CommandEdit(commandEdit1.id, command1, commandEdit1.base_revision, commandEdit1.revision) def test_CommandEdit_eq_different_base_revisoin(): assert commandEdit1 != CommandEdit(commandEdit1.id, commandEdit1.command, 1233, commandEdit1.revision) def test_CommandEdit_eq_different_revision(): assert commandEdit1 != CommandEdit(commandEdit1.id, commandEdit1.command, commandEdit1.base_revision, 1236) def test_CommandEdit_str(): assert str(commandEdit1) == '# ' + str(command2) def test_CommandEdit_repr(): assert eval(repr(commandEdit1)) == commandEdit1 commandNoop1 = CommandNoop(42, command2, 1234) def test_CommandNoop_eq_same(): assert commandNoop1 == commandNoop1 def test_CommandNoop_eq_equal(): assert commandNoop1 == CommandNoop(42, command2, 1234) def test_CommandNoop_eq_different_id(): assert commandNoop1 != CommandNoop(43, commandNoop1.command, commandNoop1.revision) def test_CommandNoop_eq_different_command(): assert commandNoop1 != CommandNoop(commandNoop1.id, command1, commandNoop1.revision) def test_CommandNoop_eq_different_revision(): assert commandNoop1 != CommandNoop(commandNoop1.id, commandNoop1.command, 1235) def test_CommandNoop_str(): assert str(commandNoop1) == '# ' + str(command2) def test_CommandNoop_repr(): assert eval(repr(commandNoop1)) == commandNoop1 commandWithMissingPage = Command('Page that definitely does not exist', command2.actions) commandPageMissing1 = CommandPageMissing(42, commandWithMissingPage, '2019-03-11T23:26:02Z') def test_CommandPageMissing_can_retry_immediately(): 
assert not commandPageMissing1.can_retry_immediately() def test_CommandPageMissing_can_continue_batch(): assert commandPageMissing1.can_continue_batch() def test_CommandPageMissing_eq_same(): assert commandPageMissing1 == commandPageMissing1 def test_CommandPageMissing_eq_equal(): assert commandPageMissing1 == CommandPageMissing(42, commandWithMissingPage, '2019-03-11T23:26:02Z') def test_CommandPageMissing_eq_different_id(): assert commandPageMissing1 != CommandPageMissing(43, commandPageMissing1.command, commandPageMissing1.curtimestamp) def test_CommandPageMissing_eq_different_command(): assert commandPageMissing1 != CommandPageMissing(commandPageMissing1.id, command2, commandPageMissing1.curtimestamp) def test_CommandPageMissing_eq_different_curtimestamp(): assert commandPageMissing1 != CommandPageMissing(commandPageMissing1.id, commandPageMissing1.command, '2019-03-11T23:28:12Z') def test_CommandPageMissing_str(): assert str(commandPageMissing1) == '# ' + str(commandWithMissingPage) def test_CommandPageMissing_repr(): assert eval(repr(commandPageMissing1)) == commandPageMissing1 commandEditConflict1 = CommandEditConflict(42, command1) def test_CommandEditConflict_can_retry_immediately(): assert commandEditConflict1.can_retry_immediately() def test_CommandEditConflict_can_continue_batch(): assert commandEditConflict1.can_continue_batch() def test_CommandEditConflict_eq_same(): assert commandEditConflict1 == commandEditConflict1 def test_CommandEditConflict_eq_equal(): assert commandEditConflict1 == CommandEditConflict(42, command1) def test_CommandEditConflict_eq_different_id(): assert commandEditConflict1 != CommandEditConflict(43, commandEditConflict1.command) def test_CommandEditConflict_eq_different_command(): assert commandEditConflict1 != CommandEditConflict(commandEditConflict1.id, command2) def test_CommandEditConflict_str(): assert str(commandEditConflict1) == '# ' + str(command1) def test_CommandEditConflict_repr(): assert 
eval(repr(commandEditConflict1)) == commandEditConflict1 -commandMaxlagExceeded1 = CommandMaxlagExceeded(42, command1, datetime.datetime(2019, 3, 16, 15, 24, 2, 607831, tzinfo=datetime.timezone.utc)) +commandMaxlagExceeded1 = CommandMaxlagExceeded(42, command1, datetime.datetime(2019, 3, 16, 15, 24, 2, tzinfo=datetime.timezone.utc)) def test_CommandMaxlagExceeded_can_retry_immediately(): assert not commandMaxlagExceeded1.can_retry_immediately() def test_CommandMaxlagExceeded_can_continue_batch(): assert not commandMaxlagExceeded1.can_continue_batch() def test_CommandMaxlagExceeded_eq_same(): assert commandMaxlagExceeded1 == commandMaxlagExceeded1 def test_CommandMaxlagExceeded_eq_equal(): - assert commandMaxlagExceeded1 == CommandMaxlagExceeded(42, command1, datetime.datetime(2019, 3, 16, 15, 24, 2, 607831, tzinfo=datetime.timezone.utc)) + assert commandMaxlagExceeded1 == CommandMaxlagExceeded(42, command1, datetime.datetime(2019, 3, 16, 15, 24, 2, tzinfo=datetime.timezone.utc)) def test_CommandMaxlagExceeded_eq_different_id(): assert commandMaxlagExceeded1 != CommandMaxlagExceeded(43, commandMaxlagExceeded1.command, commandMaxlagExceeded1.retry_after) def test_CommandMaxlagExceeded_eq_different_command(): assert commandMaxlagExceeded1 != CommandMaxlagExceeded(commandMaxlagExceeded1.id, command2, commandMaxlagExceeded1.retry_after) def test_CommandMaxlagExceeded_eq_different_retry_after(): - assert commandMaxlagExceeded1 != CommandMaxlagExceeded(commandMaxlagExceeded1.id, commandMaxlagExceeded1.command, datetime.datetime(2019, 3, 16, 15, 24, 2, 607831, tzinfo=datetime.timezone.max)) + assert commandMaxlagExceeded1 != CommandMaxlagExceeded(commandMaxlagExceeded1.id, commandMaxlagExceeded1.command, datetime.datetime(2019, 3, 16, 15, 24, 2, tzinfo=datetime.timezone.max)) def test_CommandMaxlagExceeded_str(): assert str(commandMaxlagExceeded1) == '# ' + str(command1) def test_CommandMaxlagExceeded_repr(): assert eval(repr(commandMaxlagExceeded1)) == 
commandMaxlagExceeded1 blockinfo = { 'blockedby': 'Lucas Werkmeister', 'blockedbyid': 1, 'blockedtimestamp': '2019-03-16T17:44:22Z', 'blockexpiry': 'indefinite', 'blockid': 1, 'blockpartial': False, 'blockreason': 'my custom reason', } commandBlocked1 = CommandBlocked(42, command1, False, blockinfo) commandBlocked2 = CommandBlocked(42, command1, False, None) def test_CommandBlocked_can_retry_immediately(): assert not commandBlocked1.can_retry_immediately() def test_CommandBlocked_can_continue_batch(): assert not commandBlocked1.can_continue_batch() def test_CommandBlocked_eq_same(): assert commandBlocked1 == commandBlocked1 def test_CommandBlocked_eq_equal(): assert commandBlocked1 == CommandBlocked(42, command1, False, blockinfo) def test_CommandBlocked_eq_different_id(): assert commandBlocked1 != CommandBlocked(43, commandBlocked1.command, commandBlocked1.auto, commandBlocked1.blockinfo) def test_CommandBlocked_eq_different_command(): assert commandBlocked1 != CommandBlocked(commandBlocked1.id, command2, commandBlocked1.auto, commandBlocked1.blockinfo) def test_CommandBlocked_eq_different_auto(): assert commandBlocked1 != CommandBlocked(commandBlocked1.id, commandBlocked1.command, True, blockinfo) def test_CommandBlocked_eq_different_blockinfo(): assert commandBlocked1 != CommandBlocked(commandBlocked1.id, commandBlocked1.command, commandBlocked1.auto, None) def test_CommandBlocked_str(): assert str(commandBlocked1) == '# ' + str(command1) def test_CommandBlocked_repr(): assert eval(repr(commandBlocked1)) == commandBlocked1 commandWikiReadOnly1 = CommandWikiReadOnly(42, command1, 'maintenance') commandWikiReadOnly2 = CommandWikiReadOnly(42, command1, None) def test_CommandWikiReadOnly_can_retry_immediately(): assert not commandWikiReadOnly1.can_retry_immediately() def test_CommandWikiReadOnly_can_continue_batch(): assert not commandWikiReadOnly1.can_continue_batch() def test_CommandWikiReadOnly_eq_same(): assert commandWikiReadOnly1 == commandWikiReadOnly1 def 
test_CommandWikiReadOnly_eq_equal(): assert commandWikiReadOnly1 == CommandWikiReadOnly(42, command1, 'maintenance') def test_CommandWikiReadOnly_eq_different_id(): assert commandWikiReadOnly1 != CommandWikiReadOnly(43, commandWikiReadOnly1.command, commandWikiReadOnly1.reason) def test_CommandWikiReadOnly_eq_different_command(): assert commandWikiReadOnly1 != CommandWikiReadOnly(commandWikiReadOnly1.id, command2, commandWikiReadOnly1.reason) def test_CommandWikiReadOnly_eq_different_reason(): assert commandWikiReadOnly1 != CommandWikiReadOnly(commandWikiReadOnly1.id, commandWikiReadOnly1.command, None) def test_CommandWikiReadOnly_str(): assert str(commandWikiReadOnly1) == '# ' + str(command1) def test_CommandWikiReadOnly_repr(): assert eval(repr(commandWikiReadOnly1)) == commandWikiReadOnly1 diff --git a/test_store.py b/test_store.py index 3ff1914..8a8a2e5 100644 --- a/test_store.py +++ b/test_store.py @@ -1,271 +1,272 @@ import contextlib import datetime import json import os import pymysql import pytest import random import string from command import CommandEdit, CommandNoop from store import InMemoryStore, DatabaseStore, _DatabaseCommandRecords, _StringTableStore from test_batch import newBatch1 from test_command import commandPlan1, commandEdit1, commandNoop1, commandPageMissing1, commandEditConflict1, commandMaxlagExceeded1, commandBlocked1, blockinfo, commandBlocked2, commandWikiReadOnly1, commandWikiReadOnly2 from test_utils import FakeSession fake_session = FakeSession({ 'query': { 'userinfo': { 'id': 6198807, 'name': 'Lucas Werkmeister', 'centralids': { 'CentralAuth': 46054761, 'local': 6198807 }, 'attachedlocal': { 'CentralAuth': '', 'local': '' } } } }) fake_session.host = 'https://commons.wikimedia.org' def test_InMemoryStore_store_batch_command_ids(): open_batch = InMemoryStore().store_batch(newBatch1, fake_session) assert len(open_batch.command_records) == 2 assert open_batch.command_records[0].id != open_batch.command_records[1].id def 
test_InMemoryStore_store_batch_batch_ids(): store = InMemoryStore() open_batch_1 = store.store_batch(newBatch1, fake_session) open_batch_2 = store.store_batch(newBatch1, fake_session) assert open_batch_1.id != open_batch_2.id def test_InMemoryStore_store_batch_metadata(): open_batch = InMemoryStore().store_batch(newBatch1, fake_session) assert open_batch.user_name == 'Lucas Werkmeister' assert open_batch.local_user_id == 6198807 assert open_batch.global_user_id == 46054761 assert open_batch.domain == 'commons.wikimedia.org' def test_InMemoryStore_get_batch(): store = InMemoryStore() open_batch = store.store_batch(newBatch1, fake_session) assert open_batch is store.get_batch(open_batch.id) def test_InMemoryStore_get_batch_None(): assert InMemoryStore().get_batch(0) is None def test_InMemoryStore_get_latest_batches(): store = InMemoryStore() open_batches = [] for i in range(25): open_batches.append(store.store_batch(newBatch1, fake_session)) open_batches.reverse() assert open_batches[:10] == store.get_latest_batches() @contextlib.contextmanager def temporary_database(): if 'MARIADB_ROOT_PASSWORD' not in os.environ: pytest.skip('MariaDB credentials not provided') connection = pymysql.connect(host='localhost', user='root', password=os.environ['MARIADB_ROOT_PASSWORD']) database_name = 'quickcategories_test_' + ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(16)) user_name = 'quickcategories_test_user_' + ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(16)) user_password = 'quickcategories_test_password_' + ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(16)) try: with connection.cursor() as cursor: cursor.execute('CREATE DATABASE `%s`' % database_name) cursor.execute('GRANT ALL PRIVILEGES ON `%s`.* TO `%s` IDENTIFIED BY %%s' % (database_name, user_name), (user_password,)) cursor.execute('USE `%s`' % database_name) with open('tables.sql') as tables: queries = tables.read() # PyMySQL 
does not support multiple queries in execute(), so we have to split for query in queries.split(';'): query = query.strip() if query: cursor.execute(query) cursor.execute('USE mysql') connection.commit() yield {'host': 'localhost', 'user': user_name, 'password': user_password, 'db': database_name} finally: with connection.cursor() as cursor: cursor.execute('DROP DATABASE IF EXISTS `%s`' % database_name) cursor.execute('DROP USER IF EXISTS `%s`' % user_name) connection.commit() connection.close() def test_DatabaseStore_store_batch(): with temporary_database() as connection_params: store = DatabaseStore(connection_params) open_batch = store.store_batch(newBatch1, fake_session) command2 = open_batch.command_records[1] with store._connect() as connection: with connection.cursor() as cursor: cursor.execute('SELECT * FROM `batch`') with connection.cursor() as cursor: cursor.execute('SELECT * FROM `command`') with connection.cursor() as cursor: cursor.execute('SELECT `command_page`, `actions_tpsv` FROM `command` JOIN `actions` on `command_actions_id` = `actions_id` WHERE `command_id` = %s AND `command_batch` = %s', (command2.id, open_batch.id)) command2_page, command2_actions_tpsv = cursor.fetchone() assert command2_page == command2.command.page assert command2_actions_tpsv == command2.command.actions_tpsv() def test_DatabaseStore_get_batch(): with temporary_database() as connection_params: store = DatabaseStore(connection_params) stored_batch = store.store_batch(newBatch1, fake_session) loaded_batch = store.get_batch(stored_batch.id) assert loaded_batch.id == stored_batch.id assert loaded_batch.user_name == 'Lucas Werkmeister' assert loaded_batch.local_user_id == 6198807 assert loaded_batch.global_user_id == 46054761 assert loaded_batch.domain == 'commons.wikimedia.org' assert len(loaded_batch.command_records) == 2 assert loaded_batch.command_records[0] == stored_batch.command_records[0] assert loaded_batch.command_records[1] == stored_batch.command_records[1] assert 
loaded_batch.command_records[0:2] == stored_batch.command_records[0:2] def test_DatabaseStore_get_batch_missing(): with temporary_database() as connection_params: store = DatabaseStore(connection_params) loaded_batch = store.get_batch(1) assert loaded_batch is None def test_DatabaseStore_update_batch(): with temporary_database() as connection_params: store = DatabaseStore(connection_params) stored_batch = store.store_batch(newBatch1, fake_session) loaded_batch = store.get_batch(stored_batch.id) [command_plan_1, command_plan_2] = loaded_batch.command_records[0:2] command_edit = CommandEdit(command_plan_1.id, command_plan_1.command, 1234, 1235) loaded_batch.command_records[0] = command_edit command_edit_loaded = loaded_batch.command_records[0] assert command_edit == command_edit_loaded command_noop = CommandNoop(command_plan_1.id, command_plan_1.command, 1234) loaded_batch.command_records[1] = command_noop command_noop_loaded = loaded_batch.command_records[0] assert command_noop == command_noop_loaded assert stored_batch.command_records[0:2] == loaded_batch.command_records[0:2] # TODO ideally, the timestamps on stored_batch and loaded_batch would update as well reloaded_batch = store.get_batch(stored_batch.id) assert reloaded_batch.last_updated > reloaded_batch.created def test_DatabaseStore_get_latest_batches(): with temporary_database() as connection_params: store = DatabaseStore(connection_params) open_batches = [] for i in range(25): open_batches.append(store.store_batch(newBatch1, fake_session)) open_batches.reverse() assert open_batches[:10] == store.get_latest_batches() def test_DatabaseStore_datetime_to_utc_timestamp(): store = DatabaseStore({}) - dt = datetime.datetime(2019, 3, 17, 13, 23, 28, 251638, tzinfo=datetime.timezone.utc) - assert store._datetime_to_utc_timestamp(dt) == 1552829008.251638 + dt = datetime.datetime(2019, 3, 17, 13, 23, 28, tzinfo=datetime.timezone.utc) + assert store._datetime_to_utc_timestamp(dt) == 1552829008 
@pytest.mark.parametrize('dt', [ datetime.datetime.now(), datetime.datetime.utcnow(), datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=1))), + datetime.datetime(2019, 3, 17, 13, 23, 28, 251638, tzinfo=datetime.timezone.utc) ]) def test_DatabaseStore_datetime_to_utc_timestamp_invalid_timezone(dt): store = DatabaseStore({}) with pytest.raises(AssertionError): store._datetime_to_utc_timestamp(dt) def test_DatabaseStore_utc_timestamp_to_datetime(): store = DatabaseStore({}) - dt = datetime.datetime(2019, 3, 17, 13, 23, 28, 251638, tzinfo=datetime.timezone.utc) - assert store._utc_timestamp_to_datetime(1552829008.251638) == dt + dt = datetime.datetime(2019, 3, 17, 13, 23, 28, tzinfo=datetime.timezone.utc) + assert store._utc_timestamp_to_datetime(1552829008) == dt command_records_and_rows = [ # (commandPlan1, (DatabaseStore._COMMAND_STATUS_PLAN, None)), # not supported for update, but perhaps turn into test for initial store? (commandEdit1, (DatabaseStore._COMMAND_STATUS_EDIT, {'base_revision': 1234, 'revision': 1235})), (commandNoop1, (DatabaseStore._COMMAND_STATUS_NOOP, {'revision': 1234})), (commandPageMissing1, (DatabaseStore._COMMAND_STATUS_PAGE_MISSING, {'curtimestamp': '2019-03-11T23:26:02Z'})), (commandEditConflict1, (DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT, {})), - (commandMaxlagExceeded1, (DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED, {'retry_after_utc_timestamp': 1552749842.607831})), + (commandMaxlagExceeded1, (DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED, {'retry_after_utc_timestamp': 1552749842})), (commandBlocked1, (DatabaseStore._COMMAND_STATUS_BLOCKED, {'auto': False, 'blockinfo': blockinfo})), (commandBlocked2, (DatabaseStore._COMMAND_STATUS_BLOCKED, {'auto': False, 'blockinfo': None})), (commandWikiReadOnly1, (DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY, {'reason': 'maintenance'})), (commandWikiReadOnly2, (DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY, {'reason': None})), ] @pytest.mark.parametrize('command_record, expected_row', 
command_records_and_rows) def test_DatabaseCommandRecords_command_record_to_row(command_record, expected_row): actual_row = _DatabaseCommandRecords(0, DatabaseStore({}))._command_record_to_row(command_record) assert expected_row == actual_row @pytest.mark.parametrize('expected_command_record, row', command_records_and_rows) def test_DatabaseCommandRecords_row_to_command_record(expected_command_record, row): status, outcome = row full_row = expected_command_record.id, expected_command_record.command.page, expected_command_record.command.actions_tpsv(), status, json.dumps(outcome) actual_command_record = _DatabaseCommandRecords(0, DatabaseStore({}))._row_to_command_record(*full_row) assert expected_command_record == actual_command_record @pytest.mark.parametrize('string, expected_hash', [ # all hashes obtained in MariaDB via SELECT CAST(CONV(SUBSTRING(SHA2(**string**, 256), 1, 8), 16, 10) AS unsigned int); ('', 3820012610), ('test.wikipedia.org', 3277830609), ('äöü', 3157433791), ('☺', 3752208785), ('🤔', 1622577385), ]) def test_StringTableStore_hash(string, expected_hash): store = _StringTableStore('', '', '', '') actual_hash = store._hash(string) assert expected_hash == actual_hash def test_StringTableStore_acquire_id_database(): with temporary_database() as connection_params: connection = pymysql.connect(**connection_params) try: store = _StringTableStore('domain', 'domain_id', 'domain_hash', 'domain_name') store.acquire_id(connection, 'test.wikipedia.org') with connection.cursor() as cursor: cursor.execute('SELECT domain_name FROM domain WHERE domain_hash = 3277830609') result = cursor.fetchone() assert result == ('test.wikipedia.org',) with store._cache_lock: store._cache.clear() store.acquire_id(connection, 'test.wikipedia.org') with connection.cursor() as cursor: cursor.execute('SELECT COUNT(*) FROM domain') result = cursor.fetchone() assert result == (1,) finally: connection.close() def test_StringTableStore_acquire_id_cached(): store = _StringTableStore('', 
'', '', '') with store._cache_lock: store._cache['test.wikipedia.org'] = 1 assert store.acquire_id(None, 'test.wikipedia.org') == 1