diff --git a/app.py b/app.py index e3697e6..9377016 100644 --- a/app.py +++ b/app.py @@ -1,353 +1,423 @@ # -*- coding: utf-8 -*- import datetime import flask import humanize import mwapi # type: ignore import mwoauth # type: ignore import os import random import re import requests_oauthlib # type: ignore import string import toolforge from typing import List, Optional, Tuple import yaml -from batch import StoredBatch +from batch import StoredBatch, OpenBatch from command import Command, CommandRecord, CommandPlan, CommandPending, CommandEdit, CommandNoop, CommandFailure, CommandPageMissing, CommandEditConflict, CommandMaxlagExceeded, CommandBlocked, CommandWikiReadOnly import parse_tpsv from runner import Runner import store app = flask.Flask(__name__) user_agent = toolforge.set_user_agent('quickcategories', email='mail@lucaswerkmeister.de') __dir__ = os.path.dirname(__file__) try: with open(os.path.join(__dir__, 'config.yaml')) as config_file: app.config.update(yaml.safe_load(config_file)) except FileNotFoundError: print('config.yaml file not found, assuming local development setup') app.secret_key = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(64)) if 'oauth' in app.config: consumer_token = mwoauth.ConsumerToken(app.config['oauth']['consumer_key'], app.config['oauth']['consumer_secret']) if 'database' in app.config: batch_store = store.DatabaseStore(app.config['database']) # type: store.BatchStore else: print('No database configuration, using in-memory store (batches will be lost on every restart)') batch_store = store.InMemoryStore() @app.template_global() def csrf_token() -> str: if 'csrf_token' not in flask.session: flask.session['csrf_token'] = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(64)) return flask.session['csrf_token'] @app.template_global() def form_value(name: str) -> flask.Markup: if 'repeat_form' in flask.g and name in flask.request.form: return (flask.Markup(r' value="') + 
flask.Markup.escape(flask.request.form[name]) + flask.Markup(r'" ')) else: return flask.Markup() @app.template_global() def form_attributes(name: str) -> flask.Markup: return (flask.Markup(r' id="') + flask.Markup.escape(name) + flask.Markup(r'" name="') + flask.Markup.escape(name) + flask.Markup(r'" ') + form_value(name)) @app.template_filter() def user_link(user_name: str) -> flask.Markup: return (flask.Markup(r'') + flask.Markup(r'') + flask.Markup.escape(user_name) + flask.Markup(r'') + flask.Markup(r'')) @app.template_global() def user_logged_in() -> bool: return authenticated_session() is not None @app.template_global() def authentication_area() -> flask.Markup: if 'oauth' not in app.config: return flask.Markup() if 'oauth_access_token' not in flask.session: return (flask.Markup(r'Log in')) access_token = mwoauth.AccessToken(**flask.session['oauth_access_token']) identity = mwoauth.identify('https://meta.wikimedia.org/w/index.php', consumer_token, access_token) return (flask.Markup(r'Logged in as ') + user_link(identity['username']) + flask.Markup(r'')) @app.template_global() def can_run_commands(command_records: List[CommandRecord]) -> bool: return flask.g.can_run_commands and any(filter(lambda command_record: isinstance(command_record, CommandPlan), command_records)) +@app.template_global() +def can_start_background() -> bool: + return flask.g.can_start_background + +@app.template_global() +def can_stop_background() -> bool: + return flask.g.can_stop_background + @app.template_global() # TODO make domain part of Command and turn this into a template filter? def render_command(command: Command, domain: str) -> flask.Markup: return flask.Markup(flask.render_template('command.html', domain=domain, command=command)) @app.template_global() # TODO also turn into a template filter? 
def render_command_record(command_record: CommandRecord, domain: str) -> flask.Markup: if isinstance(command_record, CommandPlan): command_record_markup = flask.render_template('command_plan.html', domain=domain, command_plan=command_record) elif isinstance(command_record, CommandPending): command_record_markup = flask.render_template('command_pending.html', domain=domain, command_pending=command_record) elif isinstance(command_record, CommandEdit): command_record_markup = flask.render_template('command_edit.html', domain=domain, command_edit=command_record) elif isinstance(command_record, CommandNoop): command_record_markup = flask.render_template('command_noop.html', domain=domain, command_noop=command_record) elif isinstance(command_record, CommandPageMissing): command_record_markup = flask.render_template('command_page_missing.html', domain=domain, command_page_missing=command_record) elif isinstance(command_record, CommandEditConflict): command_record_markup = flask.render_template('command_edit_conflict.html', domain=domain, command_edit_conflict=command_record) elif isinstance(command_record, CommandMaxlagExceeded): command_record_markup = flask.render_template('command_maxlag_exceeded.html', domain=domain, command_maxlag_exceeded=command_record) elif isinstance(command_record, CommandBlocked): command_record_markup = flask.render_template('command_blocked.html', domain=domain, command_blocked=command_record) elif isinstance(command_record, CommandWikiReadOnly): command_record_markup = flask.render_template('command_wiki_read_only.html', domain=domain, command_blocked=command_record) else: raise ValueError('Unknown command record type') return flask.Markup(command_record_markup) @app.template_filter() def render_datetime(dt: datetime.datetime) -> flask.Markup: naive_dt = dt.astimezone().replace(tzinfo=None) # humanize doesn’t support timezones :( return (flask.Markup(r'') + flask.Markup.escape(humanize.naturaltime(naive_dt)) + flask.Markup(r'')) 
@app.template_global() def render_batch_user(batch: StoredBatch) -> flask.Markup: return (flask.Markup(r'') + flask.Markup.escape(batch.user_name) + flask.Markup(r'')) def authenticated_session(domain: str = 'meta.wikimedia.org') -> Optional[mwapi.Session]: if 'oauth_access_token' in flask.session: access_token = mwoauth.AccessToken(**flask.session['oauth_access_token']) auth = requests_oauthlib.OAuth1(client_key=consumer_token.key, client_secret=consumer_token.secret, resource_owner_key=access_token.key, resource_owner_secret=access_token.secret) return mwapi.Session(host='https://'+domain, auth=auth, user_agent=user_agent) else: return None @app.route('/') def index(): return flask.render_template('index.html', latest_batches=batch_store.get_latest_batches()) @app.route('/batch', methods=['POST']) def new_batch(): if not submitted_request_valid(): return 'CSRF error', 400 domain = flask.request.form.get('domain', '(not provided)') if not is_wikimedia_domain(domain): return flask.Markup.escape(domain) + flask.Markup(' is not recognized as a Wikimedia domain'), 400 session = authenticated_session(domain) if not session: return 'not logged in', 403 # Forbidden; 401 Unauthorized would be inappropriate because we don’t send WWW-Authenticate try: batch = parse_tpsv.parse_batch(flask.request.form.get('commands', '')) except parse_tpsv.ParseBatchError as e: return str(e) batch.cleanup() id = batch_store.store_batch(batch, session).id return flask.redirect(flask.url_for('batch', id=id)) @app.route('/batch//') def batch(id: int): batch = batch_store.get_batch(id) if batch is None: return flask.render_template('batch_not_found.html', id=id), 404 session = authenticated_session(batch.domain) if session: - local_user_id = session.get(action='query', - meta='userinfo')['query']['userinfo']['id'] + userinfo = session.get(action='query', + meta='userinfo', + uiprop=['groups'])['query']['userinfo'] + local_user_id = userinfo['id'] flask.g.can_run_commands = local_user_id == 
batch.local_user_id + flask.g.can_start_background = flask.g.can_run_commands and \ + 'autoconfirmed' in userinfo['groups'] + flask.g.can_stop_background = flask.g.can_start_background or \ + 'sysop' in userinfo['groups'] else: flask.g.can_run_commands = False + flask.g.can_start_background = False + flask.g.can_stop_background = False offset, limit = slice_from_args(flask.request.args) return flask.render_template('batch.html', batch=batch, offset=offset, limit=limit) @app.route('/batch//run_slice', methods=['POST']) def run_batch_slice(id: int): batch = batch_store.get_batch(id) if batch is None: return flask.render_template('batch_not_found.html', id=id), 404 session = authenticated_session(batch.domain) if not session: return 'not logged in', 403 local_user_id = session.get(action='query', meta='userinfo')['query']['userinfo']['id'] if local_user_id != batch.local_user_id: return 'may not run this batch', 403 if 'summary_suffix' in app.config: summary_suffix = app.config['summary_suffix'].format(id) else: summary_suffix = None runner = Runner(session, summary_suffix) offset, limit = slice_from_args(flask.request.form) command_pendings = batch.command_records.make_plans_pending(offset, limit) runner.prepare_pages([command_pending.command.page for command_pending in command_pendings]) for command_pending in command_pendings: for attempt in range(5): command_finish = runner.run_command(command_pending) if isinstance(command_finish, CommandFailure) and command_finish.can_retry_immediately(): continue else: break batch.command_records.store_finish(command_finish) if isinstance(command_finish, CommandFailure) and not command_finish.can_continue_batch(): break return flask.redirect(flask.url_for('batch', id=id, offset=offset, limit=limit)) +@app.route('/batch//start_background', methods=['POST']) +def start_batch_background(id: int): + batch = batch_store.get_batch(id) + if batch is None: + return flask.render_template('batch_not_found.html', + id=id), 404 + if not 
isinstance(batch, OpenBatch): + return 'not an open batch', 400 + + session = authenticated_session(batch.domain) + if not session: + return 'not logged in', 403 + userinfo = session.get(action='query', + meta='userinfo', + uiprop=['groups'])['query']['userinfo'] + local_user_id = userinfo['id'] + if local_user_id != batch.local_user_id or \ + 'autoconfirmed' not in userinfo['groups']: + return 'may not start this batch in background', 403 + + batch_store.start_background(batch, session) + + offset, limit = slice_from_args(flask.request.form) + return flask.redirect(flask.url_for('batch', + id=id, + offset=offset, + limit=limit)) + +@app.route('/batch//stop_background', methods=['POST']) +def stop_batch_background(id: int): + batch = batch_store.get_batch(id) + if batch is None: + return flask.render_template('batch_not_found.html', + id=id), 404 + + session = authenticated_session(batch.domain) + if not session: + return 'not logged in', 403 + userinfo = session.get(action='query', + meta='userinfo', + uiprop=['groups'])['query']['userinfo'] + local_user_id = userinfo['id'] + if local_user_id != batch.local_user_id and \ + 'sysop' not in userinfo['groups']: + return 'may not stop this batch in background', 403 + + batch_store.stop_background(batch, session) + + offset, limit = slice_from_args(flask.request.form) + return flask.redirect(flask.url_for('batch', + id=id, + offset=offset, + limit=limit)) + @app.route('/login') def login(): redirect, request_token = mwoauth.initiate('https://meta.wikimedia.org/w/index.php', consumer_token, user_agent=user_agent) flask.session['oauth_request_token'] = dict(zip(request_token._fields, request_token)) return flask.redirect(redirect) @app.route('/oauth/callback') def oauth_callback(): request_token = mwoauth.RequestToken(**flask.session.pop('oauth_request_token')) access_token = mwoauth.complete('https://meta.wikimedia.org/w/index.php', consumer_token, request_token, flask.request.query_string, user_agent=user_agent) 
flask.session['oauth_access_token'] = dict(zip(access_token._fields, access_token)) return flask.redirect(flask.url_for('index')) def is_wikimedia_domain(domain: str) -> bool: return re.fullmatch(r'[a-z0-9-]+\.(?:wiki(?:pedia|media|books|data|news|quote|source|versity|voyage)|mediawiki|wiktionary)\.org', domain) is not None def slice_from_args(args: dict) -> Tuple[int, int]: try: offset = int(args['offset']) except (KeyError, ValueError): offset = 0 offset = max(0, offset) try: limit = int(args['limit']) except (KeyError, ValueError): limit = 50 limit = max(1, min(500, limit)) return offset, limit def full_url(endpoint: str, **kwargs) -> str: scheme = flask.request.headers.get('X-Forwarded-Proto', 'http') return flask.url_for(endpoint, _external=True, _scheme=scheme, **kwargs) def submitted_request_valid() -> bool: """Check whether a submitted POST request is valid. If this method returns False, the request might have been issued by an attacker as part of a Cross-Site Request Forgery attack; callers MUST NOT process the request in that case. 
""" real_token = flask.session.pop('csrf_token', None) submitted_token = flask.request.form.get('csrf_token', None) if not real_token: # we never expected a POST return False if not submitted_token: # token got lost or attacker did not supply it return False if submitted_token != real_token: # incorrect token (could be outdated or incorrectly forged) return False if not (flask.request.referrer or '').startswith(full_url('index')): # correct token but not coming from the correct page; for # example, JS running on https://tools.wmflabs.org/tool-a is # allowed to access https://tools.wmflabs.org/tool-b and # extract CSRF tokens from it (since both of these pages are # hosted on the https://tools.wmflabs.org domain), so checking # the Referer header is our only protection against attackers # from other Toolforge tools return False return True @app.after_request def deny_frame(response: flask.Response) -> flask.Response: """Disallow embedding the tool’s pages in other websites. If other websites can embed this tool’s pages, e. g. in s, other tools hosted on tools.wmflabs.org can send arbitrary web requests from this tool’s context, bypassing the referrer-based CSRF protection. """ response.headers['X-Frame-Options'] = 'deny' return response diff --git a/maintenance/0006-add-background.sql b/maintenance/0006-add-background.sql new file mode 100644 index 0000000..968bc7d --- /dev/null +++ b/maintenance/0006-add-background.sql @@ -0,0 +1,23 @@ +-- Add the background table, which records when a batch can be run in the background. 
+ +CREATE TABLE background ( + background_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT, + background_batch int unsigned NOT NULL, + background_auth text, + background_started_utc_timestamp int unsigned NOT NULL, + background_started_user_name varchar(255) binary NOT NULL, + background_started_local_user_id int unsigned NOT NULL, + background_started_global_user_id int unsigned NOT NULL, + background_stopped_utc_timestamp int unsigned, + background_stopped_user_name varchar(255) binary, + background_stopped_local_user_id int unsigned, + background_stopped_global_user_id int unsigned +) +CHARACTER SET = 'utf8mb4' +COLLATE = 'utf8mb4_bin'; + +-- index for finding the backgrounds of a batch, optionally limited to just the ones not yet stopped +CREATE INDEX background_batch_stopped ON background (background_batch, background_stopped_utc_timestamp); + +-- index for finding backgrounds not yet stopped, for any batch +CREATE INDEX background_stopped_batch ON background (background_stopped_utc_timestamp, background_batch); diff --git a/store.py b/store.py index ee4d2cd..14b712c 100644 --- a/store.py +++ b/store.py @@ -1,435 +1,504 @@ import cachetools import contextlib import datetime import hashlib import itertools import json import mwapi # type: ignore import operator import pymysql +import requests_oauthlib # type: ignore import threading from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, cast from batch import NewBatch, StoredBatch, OpenBatch, ClosedBatch, BatchCommandRecords, BatchCommandRecordsList from command import Command, CommandPlan, CommandPending, CommandRecord, CommandFinish, CommandEdit, CommandNoop, CommandPageMissing, CommandEditConflict, CommandMaxlagExceeded, CommandBlocked, CommandWikiReadOnly import parse_tpsv def _metadata_from_session(session: mwapi.Session) -> Tuple[str, int, int, str]: domain = session.host[len('https://'):] response = session.get(action='query', meta='userinfo', uiprop='centralids') user_name = 
response['query']['userinfo']['name'] local_user_id = response['query']['userinfo']['id'] global_user_id = response['query']['userinfo']['centralids']['CentralAuth'] return user_name, local_user_id, global_user_id, domain def _now() -> datetime.datetime: return datetime.datetime.now(tz=datetime.timezone.utc).replace(microsecond=0) class BatchStore: def store_batch(self, new_batch: NewBatch, session: mwapi.Session) -> OpenBatch: ... def get_batch(self, id: int) -> Optional[StoredBatch]: ... def get_latest_batches(self) -> Sequence[StoredBatch]: ... + def start_background(self, batch: OpenBatch, session: mwapi.Session) -> None: ... + + def stop_background(self, batch: StoredBatch, session: Optional[mwapi.Session] = None) -> None: ... + class InMemoryStore(BatchStore): def __init__(self): self.next_batch_id = 1 self.next_command_id = 1 self.batches = {} # type: Dict[int, StoredBatch] + self.started_backgrounds = {} # type: Dict[int, Tuple[datetime.datetime, mwapi.Session]] + self.stopped_backgrounds = {} # type: Dict[int, List[Tuple[datetime.datetime, mwapi.Session, datetime.datetime, Optional[mwapi.Session]]]] def store_batch(self, new_batch: NewBatch, session: mwapi.Session) -> OpenBatch: created = _now() user_name, local_user_id, global_user_id, domain = _metadata_from_session(session) command_plans = [] # type: List[CommandRecord] for command in new_batch.commands: command_plans.append(CommandPlan(self.next_command_id, command)) self.next_command_id += 1 open_batch = OpenBatch(self.next_batch_id, user_name, local_user_id, global_user_id, domain, created, created, BatchCommandRecordsList(command_plans)) self.next_batch_id += 1 self.batches[open_batch.id] = open_batch return open_batch def get_batch(self, id: int) -> Optional[StoredBatch]: stored_batch = self.batches.get(id) if stored_batch is None: return None command_records = cast(BatchCommandRecordsList, stored_batch.command_records).command_records if isinstance(stored_batch, OpenBatch) and \ all(map(lambda 
command_record: isinstance(command_record, CommandFinish), command_records)): stored_batch = ClosedBatch(stored_batch.id, stored_batch.user_name, stored_batch.local_user_id, stored_batch.global_user_id, stored_batch.domain, stored_batch.created, stored_batch.last_updated, stored_batch.command_records) self.batches[id] = stored_batch return stored_batch def get_latest_batches(self) -> Sequence[StoredBatch]: return [cast(StoredBatch, self.get_batch(id)) for id in sorted(self.batches.keys(), reverse=True)[:10]] + def start_background(self, batch: OpenBatch, session: mwapi.Session) -> None: + started = _now() + if batch.id not in self.started_backgrounds: + self.started_backgrounds[batch.id] = (started, session) + + def stop_background(self, batch: StoredBatch, session: Optional[mwapi.Session] = None) -> None: + stopped = _now() + if batch.id in self.started_backgrounds: + started, started_session = self.started_backgrounds[batch.id] + stopped_backgrounds = self.stopped_backgrounds.get(batch.id, []) + stopped_backgrounds.append((started, started_session, stopped, session)) + self.stopped_backgrounds[batch.id] = stopped_backgrounds + del self.started_backgrounds[batch.id] + class DatabaseStore(BatchStore): _BATCH_STATUS_OPEN = 0 _BATCH_STATUS_CLOSED = 128 _COMMAND_STATUS_PLAN = 0 _COMMAND_STATUS_EDIT = 1 _COMMAND_STATUS_NOOP = 2 _COMMAND_STATUS_PENDING = 16 _COMMAND_STATUS_PAGE_MISSING = 129 _COMMAND_STATUS_EDIT_CONFLICT = 130 _COMMAND_STATUS_MAXLAG_EXCEEDED = 131 _COMMAND_STATUS_BLOCKED = 132 _COMMAND_STATUS_WIKI_READ_ONLY = 133 def __init__(self, connection_params: dict): connection_params.setdefault('charset', 'utf8mb4') self.connection_params = connection_params self.domain_store = _StringTableStore('domain', 'domain_id', 'domain_hash', 'domain_name') self.actions_store = _StringTableStore('actions', 'actions_id', 'actions_hash', 'actions_tpsv') @contextlib.contextmanager def _connect(self) -> Generator[pymysql.connections.Connection, None, None]: connection = 
pymysql.connect(**self.connection_params) try: yield connection finally: connection.close() def _datetime_to_utc_timestamp(self, dt: datetime.datetime) -> int: assert dt.tzinfo == datetime.timezone.utc assert dt.microsecond == 0 return int(dt.timestamp()) def _utc_timestamp_to_datetime(self, timestamp: int) -> datetime.datetime: return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) def store_batch(self, new_batch: NewBatch, session: mwapi.Session) -> OpenBatch: created = _now() created_utc_timestamp = self._datetime_to_utc_timestamp(created) user_name, local_user_id, global_user_id, domain = _metadata_from_session(session) with self._connect() as connection: domain_id = self.domain_store.acquire_id(connection, domain) with connection.cursor() as cursor: cursor.execute('INSERT INTO `batch` (`batch_user_name`, `batch_local_user_id`, `batch_global_user_id`, `batch_domain_id`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status`) VALUES (%s, %s, %s, %s, %s, %s, %s)', (user_name, local_user_id, global_user_id, domain_id, created_utc_timestamp, created_utc_timestamp, DatabaseStore._BATCH_STATUS_OPEN)) batch_id = cursor.lastrowid with connection.cursor() as cursor: cursor.executemany('INSERT INTO `command` (`command_batch`, `command_page`, `command_actions_id`, `command_status`, `command_outcome`) VALUES (%s, %s, %s, %s, NULL)', [(batch_id, command.page, self.actions_store.acquire_id(connection, command.actions_tpsv()), DatabaseStore._COMMAND_STATUS_PLAN) for command in new_batch.commands]) connection.commit() return OpenBatch(batch_id, user_name, local_user_id, global_user_id, domain, created, created, _BatchCommandRecordsDatabase(batch_id, self)) def get_batch(self, id: int) -> Optional[StoredBatch]: with self._connect() as connection: with connection.cursor() as cursor: cursor.execute('''SELECT `batch_id`, `batch_user_name`, `batch_local_user_id`, `batch_global_user_id`, `domain_name`, `batch_created_utc_timestamp`, 
`batch_last_updated_utc_timestamp`, `batch_status` FROM `batch` JOIN `domain` ON `batch_domain_id` = `domain_id` WHERE `batch_id` = %s''', (id,)) result = cursor.fetchone() if not result: return None return self._result_to_batch(result) def _result_to_batch(self, result: tuple) -> StoredBatch: id, user_name, local_user_id, global_user_id, domain, created_utc_timestamp, last_updated_utc_timestamp, status = result created = self._utc_timestamp_to_datetime(created_utc_timestamp) last_updated = self._utc_timestamp_to_datetime(last_updated_utc_timestamp) if status == DatabaseStore._BATCH_STATUS_OPEN: return OpenBatch(id, user_name, local_user_id, global_user_id, domain, created, last_updated, _BatchCommandRecordsDatabase(id, self)) elif status == DatabaseStore._BATCH_STATUS_CLOSED: return ClosedBatch(id, user_name, local_user_id, global_user_id, domain, created, last_updated, _BatchCommandRecordsDatabase(id, self)) else: raise ValueError('Unknown batch type') def get_latest_batches(self) -> Sequence[StoredBatch]: with self._connect() as connection: with connection.cursor() as cursor: cursor.execute('''SELECT `batch_id`, `batch_user_name`, `batch_local_user_id`, `batch_global_user_id`, `domain_name`, `batch_created_utc_timestamp`, `batch_last_updated_utc_timestamp`, `batch_status` FROM `batch` JOIN `domain` ON `batch_domain_id` = `domain_id` ORDER BY `batch_id` DESC LIMIT 10''') return [self._result_to_batch(result) for result in cursor.fetchall()] + def start_background(self, batch: OpenBatch, session: mwapi.Session) -> None: + started = _now() + started_utc_timestamp = self._datetime_to_utc_timestamp(started) + user_name, local_user_id, global_user_id, domain = _metadata_from_session(session) + + with self._connect() as connection, connection.cursor() as cursor: + cursor.execute('''SELECT 1 + FROM `background` + WHERE `background_batch` = %s + AND `background_stopped_utc_timestamp` IS NULL + FOR UPDATE''', + (batch.id,)) + if cursor.fetchone(): + connection.commit() # 
finish the FOR UPDATE + return + + assert isinstance(session.session.auth, requests_oauthlib.OAuth1) + auth = {'resource_owner_key': session.session.auth.client.resource_owner_key, + 'resource_owner_secret': session.session.auth.client.resource_owner_secret} + + cursor.execute('''INSERT INTO `background` + (`background_batch`, `background_auth`, `background_started_utc_timestamp`, `background_started_user_name`, `background_started_local_user_id`, `background_started_global_user_id`) + VALUES (%s, %s, %s, %s, %s, %s)''', + (batch.id, json.dumps(auth), started_utc_timestamp, user_name, local_user_id, global_user_id)) + connection.commit() + + def stop_background(self, batch: StoredBatch, session: Optional[mwapi.Session] = None) -> None: + self._stop_background_by_id(batch.id, session) + + def _stop_background_by_id(self, batch_id: int, session: Optional[mwapi.Session] = None) -> None: + stopped = _now() + stopped_utc_timestamp = self._datetime_to_utc_timestamp(stopped) + if session: + user_name, local_user_id, global_user_id, domain = _metadata_from_session(session) # type: Tuple[Optional[str], Optional[int], Optional[int], str] + else: + user_name, local_user_id, global_user_id = None, None, None + with self._connect() as connection, connection.cursor() as cursor: + cursor.execute('''UPDATE `background` + SET `background_auth` = NULL, `background_stopped_utc_timestamp` = %s, `background_stopped_user_name` = %s, `background_stopped_local_user_id` = %s, `background_stopped_global_user_id` = %s + WHERE `background_batch` = %s + AND `background_stopped_utc_timestamp` IS NULL''', + (stopped_utc_timestamp, user_name, local_user_id, global_user_id, batch_id)) + connection.commit() + if cursor.rowcount > 1: + raise RuntimeError('Should have stopped at most 1 background operation, actually affected %d!' 
% cursor.rowcount) + class _BatchCommandRecordsDatabase(BatchCommandRecords): def __init__(self, batch_id: int, store: DatabaseStore): self.batch_id = batch_id self.store = store def _command_finish_to_row(self, command_finish: CommandFinish) -> Tuple[int, dict]: if isinstance(command_finish, CommandEdit): status = DatabaseStore._COMMAND_STATUS_EDIT outcome = {'base_revision': command_finish.base_revision, 'revision': command_finish.revision} # type: dict elif isinstance(command_finish, CommandNoop): status = DatabaseStore._COMMAND_STATUS_NOOP outcome = {'revision': command_finish.revision} elif isinstance(command_finish, CommandPageMissing): status = DatabaseStore._COMMAND_STATUS_PAGE_MISSING outcome = {'curtimestamp': command_finish.curtimestamp} elif isinstance(command_finish, CommandEditConflict): status = DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT outcome = {} elif isinstance(command_finish, CommandMaxlagExceeded): status = DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED outcome = {'retry_after_utc_timestamp': self.store._datetime_to_utc_timestamp(command_finish.retry_after)} elif isinstance(command_finish, CommandBlocked): status = DatabaseStore._COMMAND_STATUS_BLOCKED outcome = {'auto': command_finish.auto, 'blockinfo': command_finish.blockinfo} elif isinstance(command_finish, CommandWikiReadOnly): status = DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY outcome = {'reason': command_finish.reason} else: raise ValueError('Unknown command type') return status, outcome def _row_to_command_record(self, id: int, page: str, actions_tpsv: str, status: int, outcome: Optional[str]) -> CommandRecord: if outcome: outcome_dict = json.loads(outcome) command = Command(page, [parse_tpsv.parse_action(field) for field in actions_tpsv.split('|')]) if status == DatabaseStore._COMMAND_STATUS_PLAN: assert outcome is None return CommandPlan(id, command) elif status == DatabaseStore._COMMAND_STATUS_EDIT: return CommandEdit(id, command, base_revision=outcome_dict['base_revision'], 
                                           revision=outcome_dict['revision'])
        # (tail of _row_to_command_record — the method header lies above this chunk)
        # Maps a stored (status, outcome) DB row back to the matching CommandRecord subclass.
        elif status == DatabaseStore._COMMAND_STATUS_NOOP:
            return CommandNoop(id, command, revision=outcome_dict['revision'])
        elif status == DatabaseStore._COMMAND_STATUS_PENDING:
            # pending commands have no outcome yet
            assert outcome is None
            return CommandPending(id, command)
        elif status == DatabaseStore._COMMAND_STATUS_PAGE_MISSING:
            return CommandPageMissing(id, command, curtimestamp=outcome_dict['curtimestamp'])
        elif status == DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT:
            return CommandEditConflict(id, command)
        elif status == DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED:
            # retry-after is stored as a UTC timestamp and converted back to a datetime here
            return CommandMaxlagExceeded(id, command, self.store._utc_timestamp_to_datetime(outcome_dict['retry_after_utc_timestamp']))
        elif status == DatabaseStore._COMMAND_STATUS_BLOCKED:
            return CommandBlocked(id, command, auto=outcome_dict['auto'], blockinfo=outcome_dict['blockinfo'])
        elif status == DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY:
            return CommandWikiReadOnly(id, command, outcome_dict['reason'])
        else:
            raise ValueError('Unknown command status %d' % status)

    def get_slice(self, offset: int, limit: int) -> List[CommandRecord]:
        """Load the command records of this batch in [offset, offset+limit), ordered by id."""
        command_records = []
        with self.store._connect() as connection, connection.cursor() as cursor:
            cursor.execute('''SELECT `command_id`, `command_page`, `actions_tpsv`, `command_status`, `command_outcome`
                              FROM `command`
                              JOIN `actions` ON `command_actions_id` = `actions_id`
                              WHERE `command_batch` = %s
                              ORDER BY `command_id` ASC
                              LIMIT %s OFFSET %s''', (self.batch_id, limit, offset))
            for id, page, actions_tpsv, status, outcome in cursor.fetchall():
                command_records.append(self._row_to_command_record(id, page, actions_tpsv, status, outcome))
        return command_records

    def __len__(self) -> int:
        """Return the total number of commands in this batch."""
        with self.store._connect() as connection, connection.cursor() as cursor:
            cursor.execute('SELECT COUNT(*) FROM `command` WHERE `command_batch` = %s', (self.batch_id,))
            (count,) = cursor.fetchone()
        return count

    def make_plans_pending(self, offset: int, limit: int) -> List[CommandPending]:
        """Atomically mark the planned commands in the given slice as pending.

        Selects the commands with status PLAN inside the [offset, offset+limit)
        window (locking them with FOR UPDATE), flips their status to PENDING,
        and returns them as CommandPending records."""
        with self.store._connect() as connection:
            command_ids = []  # type: List[int]
            with connection.cursor() as cursor:
                # the extra subquery layer below is necessary to work around a MySQL/MariaDB restriction;
                # based on https://stackoverflow.com/a/24777566/1420237
                cursor.execute('''SELECT `command_id`
                                  FROM `command`
                                  WHERE `command_id` IN ( SELECT * FROM ( SELECT `command_id`
                                                                          FROM `command`
                                                                          WHERE `command_batch` = %s
                                                                          ORDER BY `command_id` ASC
                                                                          LIMIT %s OFFSET %s ) AS temporary_table)
                                  AND `command_status` = %s
                                  ORDER BY `command_id` ASC
                                  FOR UPDATE''', (self.batch_id, limit, offset, DatabaseStore._COMMAND_STATUS_PLAN))
                for (command_id,) in cursor.fetchall():
                    command_ids.append(command_id)

            if not command_ids:
                connection.commit()  # finish the FOR UPDATE
                return []

            with connection.cursor() as cursor:
                cursor.executemany('''UPDATE `command`
                                      SET `command_status` = %s
                                      WHERE `command_id` = %s AND `command_batch` = %s''',
                                   zip(itertools.repeat(DatabaseStore._COMMAND_STATUS_PENDING),
                                       command_ids,
                                       itertools.repeat(self.batch_id)))
                connection.commit()

            command_records = []
            with connection.cursor() as cursor:
                # re-load the rows we just updated; the IN (...) placeholder list is
                # built to match len(command_ids)
                cursor.execute('''SELECT `command_id`, `command_page`, `actions_tpsv`, `command_status`, `command_outcome`
                                  FROM `command`
                                  JOIN `actions` ON `command_actions_id` = `actions_id`
                                  WHERE `command_id` IN (%s)''' % ', '.join(['%s'] * len(command_ids)),
                               command_ids)
                for id, page, actions_tpsv, status, outcome in cursor.fetchall():
                    assert status == DatabaseStore._COMMAND_STATUS_PENDING
                    assert outcome is None
                    command_record = self._row_to_command_record(id, page, actions_tpsv, status, outcome)
                    assert isinstance(command_record, CommandPending)
                    command_records.append(command_record)
            return command_records

    def store_finish(self, command_finish: CommandFinish) -> None:
        """Persist the outcome of a finished command and bump the batch's last-updated time.

        If afterwards no commands with status PLAN or PENDING remain, the batch
        is closed and any background run for it is stopped."""
        last_updated = _now()
        last_updated_utc_timestamp = self.store._datetime_to_utc_timestamp(last_updated)
        status, outcome = self._command_finish_to_row(command_finish)

        with self.store._connect() as connection, connection.cursor() as cursor:
            cursor.execute('''UPDATE `command`
                              SET `command_status` = %s, `command_outcome` = %s
                              WHERE `command_id` = %s AND `command_batch` = %s''',
                           (status, json.dumps(outcome), command_finish.id, self.batch_id))
            cursor.execute('''UPDATE `batch`
                              SET `batch_last_updated_utc_timestamp` = %s
                              WHERE `batch_id` = %s''',
                           (last_updated_utc_timestamp, self.batch_id))
            connection.commit()
            # check whether any unfinished (planned or pending) commands remain
            cursor.execute('''SELECT 1
                              FROM `command`
                              WHERE `command_batch` = %s
                              AND `command_status` IN (%s, %s)
                              LIMIT 1''',
                           (self.batch_id, DatabaseStore._COMMAND_STATUS_PLAN, DatabaseStore._COMMAND_STATUS_PENDING))
            if not cursor.fetchone():
+                # close the batch
+                cursor.execute('''UPDATE `batch`
+                                  SET `batch_status` = %s
+                                  WHERE `batch_id` = %s''',
+                               (DatabaseStore._BATCH_STATUS_CLOSED, self.batch_id))
+                connection.commit()
+                # NOTE(review): newly added by this patch — a closed batch must not
+                # keep running in the background
+                self.store._stop_background_by_id(self.batch_id)

    def __eq__(self, value: Any) -> bool:
        # limited test to avoid overly expensive full comparison
        return type(value) is _BatchCommandRecordsDatabase and \
            self.batch_id == value.batch_id


class _StringTableStore:
    """Encapsulates access to a string that has been extracted into a separate table.

    The separate table is expected to have three columns: an automatically
    incrementing ID, an unsigned integer hash (the first four bytes of the
    SHA2-256 hash of the string), and the string itself.
    IDs for the least recently used strings are cached, but to look up the
    string for an ID, callers should use a plain SQL JOIN for now."""

    def __init__(self, table_name: str, id_column_name: str, hash_column_name: str, string_column_name: str):
        # table/column names are interpolated into SQL below (never user input)
        self.table_name = table_name
        self.id_column_name = id_column_name
        self.hash_column_name = hash_column_name
        self.string_column_name = string_column_name
        # string -> id cache; bounded so memory use stays constant
        self._cache = cachetools.LRUCache(maxsize=1024)  # type: cachetools.LRUCache[str, int]
        self._cache_lock = threading.RLock()

    def _hash(self, string: str) -> int:
        """Return the first four bytes of the SHA2-256 hash of *string* as an unsigned int."""
        hex = hashlib.sha256(string.encode('utf8')).hexdigest()
        return int(hex[:8], base=16)

    @cachetools.cachedmethod(operator.attrgetter('_cache'), key=lambda connection, string: string, lock=operator.attrgetter('_cache_lock'))
    def acquire_id(self, connection: pymysql.connections.Connection, string: str) -> int:
        """Return the ID of *string* in the table, inserting a new row if needed.

        Results are cached per string (see decorator); on a cache miss the row
        is looked up by hash with FOR UPDATE before a possible insert."""
        hash = self._hash(string)

        with connection.cursor() as cursor:
            cursor.execute('''SELECT `%s`
                              FROM `%s`
                              WHERE `%s` = %%s
                              FOR UPDATE''' % (self.id_column_name, self.table_name, self.hash_column_name),
                           (hash,))
            result = cursor.fetchone()

        if result:
            connection.commit()  # finish the FOR UPDATE
            return result[0]

        with connection.cursor() as cursor:
            cursor.execute('''INSERT INTO `%s` (`%s`, `%s`)
                              VALUES (%%s, %%s)''' % (self.table_name, self.string_column_name, self.hash_column_name),
                           (string, hash))
            string_id = cursor.lastrowid
        connection.commit()
        return string_id
diff --git a/tables.sql b/tables.sql
index 4819e47..73c2326 100644
--- a/tables.sql
+++ b/tables.sql
@@ -1,45 +1,67 @@
CREATE TABLE batch (
  batch_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT,
  batch_user_name varchar(255) binary NOT NULL,
  batch_local_user_id int unsigned NOT NULL,
  batch_global_user_id int unsigned NOT NULL,
  batch_domain_id int unsigned NOT NULL,
  batch_created_utc_timestamp int unsigned NOT NULL,
  batch_last_updated_utc_timestamp int unsigned NOT NULL,
  batch_status int unsigned NOT NULL
)
CHARACTER SET = 'utf8mb4'
COLLATE = 'utf8mb4_bin';

CREATE
TABLE command (
  command_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT,
  command_batch int unsigned NOT NULL,
  command_page text NOT NULL,
  command_actions_id int unsigned NOT NULL,
  command_status int unsigned NOT NULL,
  command_outcome text
)
CHARACTER SET = 'utf8mb4'
COLLATE = 'utf8mb4_bin';

CREATE INDEX command_batch_status ON command (command_batch, command_status);

CREATE TABLE domain (
  domain_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT,
  domain_hash int unsigned NOT NULL, -- first four bytes of the SHA2-256 hash of the domain_name
  domain_name varchar(255) binary NOT NULL
)
CHARACTER SET = 'utf8mb4'
COLLATE = 'utf8mb4_bin';

CREATE INDEX domain_hash ON domain (domain_hash);

CREATE TABLE actions (
  actions_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT,
  actions_hash int unsigned NOT NULL, -- first four bytes of the SHA2-256 hash of the actions_tpsv
  actions_tpsv text NOT NULL
)
CHARACTER SET = 'utf8mb4'
COLLATE = 'utf8mb4_bin';

CREATE INDEX actions_hash ON actions (actions_hash);
+
+-- one row per background run of a batch; "stopped" columns stay NULL while it is running
+CREATE TABLE background (
+  background_id int unsigned NOT NULL PRIMARY KEY AUTO_INCREMENT,
+  background_batch int unsigned NOT NULL,
+  background_auth text,
+  background_started_utc_timestamp int unsigned NOT NULL,
+  background_started_user_name varchar(255) binary NOT NULL,
+  background_started_local_user_id int unsigned NOT NULL,
+  background_started_global_user_id int unsigned NOT NULL,
+  background_stopped_utc_timestamp int unsigned,
+  background_stopped_user_name varchar(255) binary,
+  background_stopped_local_user_id int unsigned,
+  background_stopped_global_user_id int unsigned
+)
+CHARACTER SET = 'utf8mb4'
+COLLATE = 'utf8mb4_bin';
+
+-- index for finding the backgrounds of a batch, optionally limited to just the ones not yet stopped
+CREATE INDEX background_batch_stopped ON background (background_batch, background_stopped_utc_timestamp);
+
+-- index for finding backgrounds not yet stopped, for any batch
+CREATE INDEX background_stopped_batch ON background (background_stopped_utc_timestamp, background_batch);
diff --git a/templates/batch.html b/templates/batch.html
index be1b981..e1bcbd2 100644
--- a/templates/batch.html
+++ b/templates/batch.html
@@ -1,37 +1,46 @@
{% extends "base.html" %}
{% block main %}
Batch #{{ batch.id }} by {{ render_batch_user(batch) }}
-
- Targeting {{ batch.domain }}.
- Created {{ batch.created | render_datetime }},
- last updated {{ batch.last_updated | render_datetime }}.
-
+
+ Targeting {{ batch.domain }}.
+ Created {{ batch.created | render_datetime }},
+ last updated {{ batch.last_updated | render_datetime }}.
+ {% if can_stop_background() %}
+
+ Stop batch running in background
+ {% endif %}
+
+
{% if offset > 0 %}
Previous
{% else %}
Previous
{% endif %}
{% if offset + limit < (batch.command_records | length) %}
Next
{% else %}
Next
{% endif %}
{% set command_records = batch.command_records.get_slice(offset, limit) %}
{% for command_record in command_records %}
{{ render_command_record(command_record, batch.domain) }}
{% endfor %}
{% if can_run_commands(command_records) %}
Run these commands
+ {% if can_start_background() %}
+ Run whole batch in background
+ {% endif %}
{% endif %}
{% endblock %}
diff --git a/test_store.py b/test_store.py
index d2e4d87..5287314 100644
--- a/test_store.py
+++ b/test_store.py
@@ -1,306 +1,393 @@
import datetime
import json
import os
import pymysql
import pytest  # type: ignore
import random
import re
import string
import time
from typing import List, Optional, Tuple

from batch import OpenBatch, ClosedBatch
from command import CommandRecord, CommandEdit, CommandNoop
from store import InMemoryStore, DatabaseStore, _BatchCommandRecordsDatabase, _StringTableStore
from test_batch import newBatch1
from test_command import commandPlan1, commandPending1, commandEdit1, commandNoop1, commandPageMissing1, commandEditConflict1, commandMaxlagExceeded1, commandBlocked1, blockinfo, commandBlocked2, commandWikiReadOnly1, commandWikiReadOnly2
from test_utils import
FakeSession

# canned MediaWiki userinfo response shared by all tests; mimics the API shape
# the stores read user name / local id / CentralAuth global id from
fake_session = FakeSession({
    'query': {
        'userinfo': {
            'id': 6198807,
            'name': 'Lucas Werkmeister',
            'centralids': {
                'CentralAuth': 46054761,
                'local': 6198807
            },
            'attachedlocal': {
                'CentralAuth': '',
                'local': ''
            }
        }
    }
})
fake_session.host = 'https://commons.wikimedia.org'

def test_InMemoryStore_store_batch_command_ids():
    # each command record in a stored batch gets a distinct id
    open_batch = InMemoryStore().store_batch(newBatch1, fake_session)
    assert len(open_batch.command_records) == 2
    [command_record_1, command_record_2] = open_batch.command_records.get_slice(0, 2)
    assert command_record_1.id != command_record_2.id

def test_InMemoryStore_store_batch_batch_ids():
    # storing the same NewBatch twice yields two distinct batch ids
    store = InMemoryStore()
    open_batch_1 = store.store_batch(newBatch1, fake_session)
    open_batch_2 = store.store_batch(newBatch1, fake_session)
    assert open_batch_1.id != open_batch_2.id

def test_InMemoryStore_store_batch_metadata():
    # user metadata is taken from the session's userinfo, domain from its host
    open_batch = InMemoryStore().store_batch(newBatch1, fake_session)
    assert open_batch.user_name == 'Lucas Werkmeister'
    assert open_batch.local_user_id == 6198807
    assert open_batch.global_user_id == 46054761
    assert open_batch.domain == 'commons.wikimedia.org'

def test_InMemoryStore_get_batch():
    store = InMemoryStore()
    open_batch = store.store_batch(newBatch1, fake_session)
    assert open_batch is store.get_batch(open_batch.id)

def test_InMemoryStore_get_batch_None():
    assert InMemoryStore().get_batch(0) is None

def test_InMemoryStore_get_latest_batches():
    # get_latest_batches returns the ten most recent batches, newest first
    store = InMemoryStore()
    open_batches = []
    for i in range(25):
        open_batches.append(store.store_batch(newBatch1, fake_session))
    open_batches.reverse()
    assert open_batches[:10] == store.get_latest_batches()

def test_InMemoryStore_closes_batch():
    # a batch becomes ClosedBatch once its last command finishes
    store = InMemoryStore()
    open_batch = store.store_batch(newBatch1, fake_session)
    [command_record_1, command_record_2] = open_batch.command_records.get_slice(0, 2)
    open_batch.command_records.store_finish(CommandNoop(command_record_1.id, command_record_1.command, revision=1))
    assert type(store.get_batch(open_batch.id)) is OpenBatch
    open_batch.command_records.store_finish(CommandNoop(command_record_2.id, command_record_2.command, revision=2))
    assert type(store.get_batch(open_batch.id)) is ClosedBatch

+# TODO add tests for InMemoryStore start_background + stop_background
+
@pytest.fixture(scope="module")
def fresh_database_connection_params():
    """Create a throwaway test database + user, yield connection params, drop both afterwards."""
    if 'MARIADB_ROOT_PASSWORD' not in os.environ:
        pytest.skip('MariaDB credentials not provided')
    connection = pymysql.connect(host='localhost', user='root', password=os.environ['MARIADB_ROOT_PASSWORD'])
    database_name = 'quickcategories_test_' + ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(16))
    user_name = 'quickcategories_test_user_' + ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(16))
    user_password = 'quickcategories_test_password_' + ''.join(random.choice(string.ascii_lowercase + string.digits) for i in range(16))
    try:
        with connection.cursor() as cursor:
            cursor.execute('CREATE DATABASE `%s`' % database_name)
            cursor.execute('GRANT ALL PRIVILEGES ON `%s`.* TO `%s` IDENTIFIED BY %%s' % (database_name, user_name), (user_password,))
            cursor.execute('USE `%s`' % database_name)
            with open('tables.sql') as tables:
                queries = tables.read()
                # PyMySQL does not support multiple queries in execute(), so we have to split
                for query in queries.split(';'):
                    query = query.strip()
                    if query:
                        cursor.execute(query)
            connection.commit()
        yield {'host': 'localhost', 'user': user_name, 'password': user_password, 'db': database_name}
    finally:
        with connection.cursor() as cursor:
            cursor.execute('DROP DATABASE IF EXISTS `%s`' % database_name)
            cursor.execute('DROP USER IF EXISTS `%s`' % user_name)
            connection.commit()
        connection.close()

@pytest.fixture
def database_connection_params(fresh_database_connection_params):
    """Per-test wrapper: empty all tables of the module-scoped test database."""
    connection = pymysql.connect(**fresh_database_connection_params)
    try:
        with open('tables.sql') as tables:
            queries = tables.read()
        with connection.cursor() as cursor:
            for table in re.findall(r'CREATE TABLE ([^ ]+) ', queries):
                cursor.execute('DELETE FROM `%s`' % table)  # more efficient than TRUNCATE TABLE on my system :/
                # cursor.execute('ALTER TABLE `%s` AUTO_INCREMENT = 1' % table)  # currently not necessary
        connection.commit()
    finally:
        connection.close()
    return fresh_database_connection_params

@pytest.fixture(params=[InMemoryStore, DatabaseStore])
def store(request):
    """Parametrized fixture running a test against both store implementations."""
    if request.param is InMemoryStore:
        yield InMemoryStore()
    elif request.param is DatabaseStore:
        database_connection_params = request.getfixturevalue('database_connection_params')
        yield DatabaseStore(database_connection_params)
    else:
        raise ValueError('Unknown param!')

def test_DatabaseStore_store_batch(database_connection_params):
    # the stored command row round-trips page and actions through the actions table
    store = DatabaseStore(database_connection_params)
    open_batch = store.store_batch(newBatch1, fake_session)
    command2 = open_batch.command_records.get_slice(1, 1)[0]
    with store._connect() as connection:
        with connection.cursor() as cursor:
            cursor.execute('SELECT `command_page`, `actions_tpsv` FROM `command` JOIN `actions` on `command_actions_id` = `actions_id` WHERE `command_id` = %s AND `command_batch` = %s', (command2.id, open_batch.id))
            command2_page, command2_actions_tpsv = cursor.fetchone()
    assert command2_page == command2.command.page
    assert command2_actions_tpsv == command2.command.actions_tpsv()

def test_DatabaseStore_get_batch(store):
    stored_batch = store.store_batch(newBatch1, fake_session)
    loaded_batch = store.get_batch(stored_batch.id)
    assert loaded_batch.id == stored_batch.id
    assert loaded_batch.user_name == 'Lucas Werkmeister'
    assert loaded_batch.local_user_id == 6198807
    assert loaded_batch.global_user_id == 46054761
    assert loaded_batch.domain == 'commons.wikimedia.org'
    assert len(loaded_batch.command_records) == 2
    assert loaded_batch.command_records.get_slice(0, 2) == stored_batch.command_records.get_slice(0, 2)

def test_DatabaseStore_get_batch_missing(store):
    loaded_batch = store.get_batch(1)
    assert loaded_batch is None

def test_DatabaseStore_update_batch(database_connection_params):
    # store_finish overwrites a command's outcome and bumps batch_last_updated
    store = DatabaseStore(database_connection_params)
    stored_batch = store.store_batch(newBatch1, fake_session)
    loaded_batch = store.get_batch(stored_batch.id)
    [command_plan_1, command_plan_2] = loaded_batch.command_records.get_slice(0, 2)

    command_edit = CommandEdit(command_plan_1.id, command_plan_1.command, 1234, 1235)
    loaded_batch.command_records.store_finish(command_edit)
    command_edit_loaded = loaded_batch.command_records.get_slice(0, 1)[0]
    assert command_edit == command_edit_loaded

    command_noop = CommandNoop(command_plan_1.id, command_plan_1.command, 1234)
    time.sleep(1)  # make sure that this update increases last_updated
    loaded_batch.command_records.store_finish(command_noop)
    command_noop_loaded = loaded_batch.command_records.get_slice(0, 1)[0]
    assert command_noop == command_noop_loaded

    assert stored_batch.command_records.get_slice(0, 2) == loaded_batch.command_records.get_slice(0, 2)
    # TODO ideally, the timestamps on stored_batch and loaded_batch would update as well
    reloaded_batch = store.get_batch(stored_batch.id)
    assert reloaded_batch.last_updated > reloaded_batch.created

def test_DatabaseStore_closes_batch(store):
    open_batch = store.store_batch(newBatch1, fake_session)
    [command_record_1, command_record_2] = open_batch.command_records.get_slice(0, 2)
    open_batch.command_records.store_finish(CommandNoop(command_record_1.id, command_record_1.command, revision=1))
    assert type(store.get_batch(open_batch.id)) is OpenBatch
    open_batch.command_records.store_finish(CommandNoop(command_record_2.id, command_record_2.command, revision=2))
    assert type(store.get_batch(open_batch.id)) is ClosedBatch

def test_DatabaseStore_get_latest_batches(store):
    open_batches = []
    for i in range(25):
        open_batches.append(store.store_batch(newBatch1, fake_session))
    open_batches.reverse()
    assert open_batches[:10] == store.get_latest_batches()

def test_DatabaseStore_datetime_to_utc_timestamp():
    store = DatabaseStore({})
    dt = datetime.datetime(2019, 3, 17, 13, 23, 28, tzinfo=datetime.timezone.utc)
    assert store._datetime_to_utc_timestamp(dt) == 1552829008

@pytest.mark.parametrize('dt', [
    # naive, non-UTC, or sub-second datetimes must all be rejected
    datetime.datetime.now(),
    datetime.datetime.utcnow(),
    datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=1))),
    datetime.datetime(2019, 3, 17, 13, 23, 28, 251638, tzinfo=datetime.timezone.utc)
])
def test_DatabaseStore_datetime_to_utc_timestamp_invalid_timezone(dt):
    store = DatabaseStore({})
    with pytest.raises(AssertionError):
        store._datetime_to_utc_timestamp(dt)

def test_DatabaseStore_utc_timestamp_to_datetime():
    store = DatabaseStore({})
    dt = datetime.datetime(2019, 3, 17, 13, 23, 28, tzinfo=datetime.timezone.utc)
    assert store._utc_timestamp_to_datetime(1552829008) == dt

+def test_DatabaseStore_start_background_inserts_row(database_connection_params):
+    store = DatabaseStore(database_connection_params)
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.start_background(open_batch, fake_session)
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT `background_started_user_name`, `background_auth` FROM `background`')
+        assert cursor.rowcount == 1
+        user_name, auth = cursor.fetchone()
+        assert user_name == 'Lucas Werkmeister'
+        assert json.loads(auth) == {'resource_owner_key': 'fake resource owner key',
+                                    'resource_owner_secret': 'fake resource owner secret'}
+
+def test_DatabaseStore_start_background_does_not_insert_extra_row(database_connection_params):
+    store = DatabaseStore(database_connection_params)
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.start_background(open_batch, fake_session)
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT `background_id`, `background_started_utc_timestamp` FROM `background`')
+        assert cursor.rowcount == 1
+        background_id, background_started_utc_timestamp = cursor.fetchone()
+    store.start_background(open_batch, fake_session)  # should be no-op
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT `background_id`, `background_started_utc_timestamp` FROM `background`')
+        assert cursor.rowcount == 1
+        assert (background_id, background_started_utc_timestamp) == cursor.fetchone()
+
+def test_DatabaseStore_stop_background_updates_row_removes_auth(database_connection_params):
+    store = DatabaseStore(database_connection_params)
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.start_background(open_batch, fake_session)
+    store.stop_background(open_batch, fake_session)
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT `background_auth`, `background_stopped_utc_timestamp`, `background_stopped_user_name` FROM `background`')
+        assert cursor.rowcount == 1
+        auth, stopped_utc_timestamp, stopped_user_name = cursor.fetchone()
+        assert stopped_utc_timestamp > 0
+        assert stopped_user_name == 'Lucas Werkmeister'
+        assert auth is None
+
+def test_DatabaseStore_stop_background_without_session(database_connection_params):
+    store = DatabaseStore(database_connection_params)
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.start_background(open_batch, fake_session)
+    store.stop_background(open_batch)
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT `background_stopped_utc_timestamp`, `background_stopped_user_name` FROM `background`')
+        assert cursor.rowcount == 1
+        stopped_utc_timestamp, stopped_user_name = cursor.fetchone()
+        assert stopped_utc_timestamp > 0
+        assert stopped_user_name is None
+
+def test_DatabaseStore_stop_background_noop(store):
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.stop_background(open_batch)
+    # no error
+
+def test_DatabaseStore_stop_background_multiple_closes_all_raises_exception(database_connection_params):
+    store = DatabaseStore(database_connection_params)
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.start_background(open_batch, fake_session)
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('INSERT INTO `background` (`background_batch`, `background_auth`, `background_started_utc_timestamp`, `background_started_user_name`, `background_started_local_user_id`, `background_started_global_user_id`) SELECT `background_batch`, `background_auth`, `background_started_utc_timestamp`, `background_started_user_name`, `background_started_local_user_id`, `background_started_global_user_id` FROM `background`')
+        connection.commit()
+    with pytest.raises(RuntimeError, match='Should have stopped at most 1 background operation, actually affected 2!'):
+        store.stop_background(open_batch)
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT 1 FROM `background` WHERE `background_stopped_utc_timestamp` IS NOT NULL')
+        assert cursor.rowcount == 2
+        cursor.execute('SELECT 1 FROM `background` WHERE `background_stopped_utc_timestamp` IS NULL')
+        assert cursor.rowcount == 0
+
+def test_DatabaseStore_closing_batch_stops_background(database_connection_params):
+    store = DatabaseStore(database_connection_params)
+    open_batch = store.store_batch(newBatch1, fake_session)
+    store.start_background(open_batch, fake_session)
+    [command_record_1, command_record_2] = open_batch.command_records.get_slice(0, 2)
+    open_batch.command_records.store_finish(CommandNoop(command_record_1.id, command_record_1.command, revision=1))
+    open_batch.command_records.store_finish(CommandNoop(command_record_2.id, command_record_2.command, revision=2))
+    with store._connect() as connection, connection.cursor() as cursor:
+        cursor.execute('SELECT `background_stopped_utc_timestamp`, `background_stopped_user_name` FROM `background`')
+        assert cursor.rowcount == 1
+        stopped_utc_timestamp, stopped_user_name = cursor.fetchone()
+        assert stopped_utc_timestamp > 0
+        assert stopped_user_name is None
+
# unfinished records map to (status, None); finished ones to (status, outcome dict)
command_unfinishes_and_rows = [
    (commandPlan1, (DatabaseStore._COMMAND_STATUS_PLAN, None)),
    (commandPending1, (DatabaseStore._COMMAND_STATUS_PENDING, None)),
]  # type: List[Tuple[CommandRecord, Tuple[int, Optional[dict]]]]
command_finishes_and_rows = [
    (commandEdit1, (DatabaseStore._COMMAND_STATUS_EDIT, {'base_revision': 1234, 'revision': 1235})),
    (commandNoop1, (DatabaseStore._COMMAND_STATUS_NOOP, {'revision': 1234})),
    (commandPageMissing1, (DatabaseStore._COMMAND_STATUS_PAGE_MISSING, {'curtimestamp': '2019-03-11T23:26:02Z'})),
    (commandEditConflict1, (DatabaseStore._COMMAND_STATUS_EDIT_CONFLICT, {})),
    (commandMaxlagExceeded1, (DatabaseStore._COMMAND_STATUS_MAXLAG_EXCEEDED, {'retry_after_utc_timestamp': 1552749842})),
    (commandBlocked1, (DatabaseStore._COMMAND_STATUS_BLOCKED, {'auto': False, 'blockinfo': blockinfo})),
    (commandBlocked2, (DatabaseStore._COMMAND_STATUS_BLOCKED, {'auto': False, 'blockinfo': None})),
    (commandWikiReadOnly1, (DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY, {'reason': 'maintenance'})),
    (commandWikiReadOnly2, (DatabaseStore._COMMAND_STATUS_WIKI_READ_ONLY, {'reason': None})),
]  # type: List[Tuple[CommandRecord, Tuple[int, Optional[dict]]]]

@pytest.mark.parametrize('command_finish, expected_row', command_finishes_and_rows)
def test_BatchCommandRecordsDatabase_command_finish_to_row(command_finish, expected_row):
    actual_row = _BatchCommandRecordsDatabase(0, DatabaseStore({}))._command_finish_to_row(command_finish)
    assert expected_row == actual_row

@pytest.mark.parametrize('expected_command_record, row', command_unfinishes_and_rows + command_finishes_and_rows)
def test_BatchCommandRecordsDatabase_row_to_command_record(expected_command_record, row):
    # round-trip: build the full DB row and map it back to a command record
    status, outcome = row
    outcome_json = json.dumps(outcome) if outcome else None
    full_row = expected_command_record.id, expected_command_record.command.page, expected_command_record.command.actions_tpsv(), status, outcome_json
    actual_command_record = _BatchCommandRecordsDatabase(0, DatabaseStore({}))._row_to_command_record(*full_row)
    assert expected_command_record == actual_command_record

@pytest.mark.parametrize('string, expected_hash', [
    # all hashes obtained in MariaDB via SELECT CAST(CONV(SUBSTRING(SHA2(**string**, 256), 1, 8), 16, 10) AS unsigned int);
    ('', 3820012610),
    ('test.wikipedia.org', 3277830609),
    ('äöü', 3157433791),
    ('☺', 3752208785),
    ('🤔', 1622577385),
])
def test_StringTableStore_hash(string, expected_hash):
    # _hash must agree with MariaDB's SHA2-based hashing, see comment above
    store = _StringTableStore('', '', '', '')
    actual_hash = store._hash(string)
    assert expected_hash == actual_hash

def test_StringTableStore_acquire_id_database(database_connection_params):
    connection = pymysql.connect(**database_connection_params)
    try:
        store = _StringTableStore('domain', 'domain_id', 'domain_hash', 'domain_name')

        store.acquire_id(connection, 'test.wikipedia.org')

        with connection.cursor() as cursor:
            cursor.execute('SELECT domain_name FROM domain WHERE domain_hash = 3277830609')
            result = cursor.fetchone()
        assert result == ('test.wikipedia.org',)

        # clear the in-memory cache to force a second DB lookup; it must not insert again
        with store._cache_lock:
            store._cache.clear()

        store.acquire_id(connection, 'test.wikipedia.org')

        with connection.cursor() as cursor:
            cursor.execute('SELECT COUNT(*) FROM domain')
            result = cursor.fetchone()
        assert result == (1,)
    finally:
        connection.close()

def test_StringTableStore_acquire_id_cached():
    # a cache hit never touches the connection (None works as a stand-in)
    store = _StringTableStore('', '', '', '')
    with store._cache_lock:
        store._cache['test.wikipedia.org'] = 1
    assert store.acquire_id(None, 'test.wikipedia.org') == 1
diff --git a/test_utils.py b/test_utils.py
index e46ca37..ea5503f 100644
--- a/test_utils.py
+++ b/test_utils.py
@@ -1,18 +1,25 @@
+import requests
+import requests_oauthlib  # type: ignore
+
+
class FakeSession:
    # Test double for mwapi.Session: returns canned responses and carries a
    # requests.Session with fake OAuth1 credentials for code that reads session.auth.
    def __init__(self, get_response, post_response=None):
        self.get_response = get_response
        self.post_response = post_response
        self.host = None
+        self.session = requests.Session()
+        self.session.auth = requests_oauthlib.OAuth1(client_key='fake client key', client_secret='fake client secret',
+                                                     resource_owner_key='fake resource owner key', resource_owner_secret='fake resource owner secret')

    def get(self, *args, **kwargs):
        return self.get_response

    def post(self, *args, **kwargs):
        # a stored exception is raised instead of returned; no canned response means the
        # test did not expect a POST at all
        if self.post_response:
            if isinstance(self.post_response, BaseException):
                raise self.post_response
            else:
                return self.post_response
        else:
            raise NotImplementedError
- Targeting {{ batch.domain }}. - Created {{ batch.created | render_datetime }}, - last updated {{ batch.last_updated | render_datetime }}. -
+ Targeting {{ batch.domain }}. + Created {{ batch.created | render_datetime }}, + last updated {{ batch.last_updated | render_datetime }}. + {% if can_stop_background() %} + + Stop batch running in background + {% endif %} +