diff --git a/app.py b/app.py
index eebb39c..e883476 100644
--- a/app.py
+++ b/app.py
@@ -1,323 +1,337 @@
# -*- coding: utf-8 -*-
import flask
import functools
from io import StringIO
import mwoauth
import os
import pymysql
import random
import requests
import requests_oauthlib
import string
import toolforge
import urllib.parse
import yaml
try:
from pygments.lexers import ShExCLexer
except ImportError:
have_pygments = False
else:
import pygments
from pygments.formatters import HtmlFormatter
have_pygments = True
from job import Job, null_job
from job_store import SqlJobStore, LocalFileJobStore
from job_runner import GridEngineJobRunner
from job_manager import JobManager, RejectJobDueToBlocks, RejectJobDueToPendingJobs
# The Flask application object; routes, template filters and template globals below are registered on it.
app = flask.Flask(__name__)
# Register toolforge's HTTPS redirect helper as a before-request hook.
app.before_request(toolforge.redirect_to_https)
# Set the tool's user agent, then capture the resulting default so it can be
# passed explicitly to the mwoauth calls further down.
toolforge.set_user_agent('wd-shex-infer', email='mail@lucaswerkmeister.de')
user_agent = requests.utils.default_user_agent()
-__dir__ = os.path.dirname(__file__)
-try:
- with open(os.path.join(__dir__, 'config.yaml'), 'r', encoding='utf-8') as config_file:
- app.config.update(yaml.safe_load(config_file))
-except FileNotFoundError:
+def read_private(func):
+    """Wrap a loader so it refuses to read group- or world-readable files.
+
+    The returned callable forwards all arguments to ``func``. If the first
+    positional argument exposes a ``fileno()`` method, the permission bits of
+    the underlying file are checked first; arguments without ``fileno()`` (or
+    calls without positional arguments) are passed through unchecked.
+
+    :param func: the loader to wrap, e.g. ``yaml.safe_load``
+    :returns: a wrapper with the same call signature and metadata as ``func``
+    :raises ValueError: (from the wrapper) if the file is readable by group
+        or others
+    """
+    # Imported locally because the module-level imports do not include stat.
+    import stat
+
+    @functools.wraps(func)
+    def checked(*args, **kwargs):
+        try:
+            f = args[0]
+            fd = f.fileno()
+        except (AttributeError, IndexError):
+            # not a file object (or no positional argument): nothing to check
+            pass
+        else:
+            mode = os.fstat(fd).st_mode
+            if mode & (stat.S_IRGRP | stat.S_IROTH):
+                raise ValueError(getattr(f, "name", "config file") +
+                                 ' is readable to others, ' +
+                                 'must be exclusively user-readable!')
+        return func(*args, **kwargs)
+
+    return checked
+
# Load config.yaml through yaml.safe_load, wrapped in read_private so the
# permission check runs on the opened file. has_config is falsy when nothing
# was loaded, which is treated as a local development setup.
+has_config = app.config.from_file('config.yaml', load=read_private(yaml.safe_load), silent=True)
+if not has_config:
print('config.yaml file not found, assuming local development setup')
app.secret_key = 'fake secret key so we can still use flask.session'
# Without an OAUTH section there is no consumer token; the OAuth helpers below
# check consumer_token and fall back to a fake test identity.
-if 'oauth' in app.config:
- consumer_token = mwoauth.ConsumerToken(app.config['oauth']['consumer_key'], app.config['oauth']['consumer_secret'])
+if 'OAUTH' in app.config:
+ consumer_token = mwoauth.ConsumerToken(app.config['OAUTH']['consumer_key'], app.config['OAUTH']['consumer_secret'])
else:
consumer_token = None
wikidata_url = 'https://www.wikidata.org/w/index.php'
# NOTE(review): DATABASE, FILES and RDF2GRAPH are indexed unconditionally, so a
# configuration without them raises KeyError here - confirm this is intended
# for the "no config.yaml" development case.
-connection = pymysql.connect(charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor, **app.config['database'])
-job_store = LocalFileJobStore(SqlJobStore(connection), app.config['files']) # update require_connection() if implementations here change
-job_runner = GridEngineJobRunner(app.config['rdf2graph'])
-job_manager = JobManager(job_store, job_runner, app.config.get('blocks_directory'))
+connection = pymysql.connect(charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor, **app.config['DATABASE'])
+job_store = LocalFileJobStore(SqlJobStore(connection), app.config['FILES']) # update require_connection() if implementations here change
+job_runner = GridEngineJobRunner(app.config['RDF2GRAPH'])
+job_manager = JobManager(job_store, job_runner, app.config.get('BLOCKS_DIRECTORY'))
connection.close() # this connection was only to initialize the database, web requests have their own connections
@app.template_global()
def csrf_token():
    """Return the session's CSRF token, generating one on first use.

    The token is 64 ASCII letters/digits, stored in the Flask session under
    ``_csrf_token`` and exposed to templates as a global function.

    :returns: the CSRF token string for the current session
    """
    # Imported locally so this block does not depend on a module-level import;
    # CSRF tokens are security-sensitive, so use a CSPRNG rather than `random`.
    import secrets

    if '_csrf_token' not in flask.session:
        alphabet = string.ascii_letters + string.digits
        flask.session['_csrf_token'] = ''.join(secrets.choice(alphabet) for _ in range(64))
    return flask.session['_csrf_token']
@app.template_filter()
def time_element(datetime, previous_datetime=None):
    """Render a datetime as an HTML ``<time>`` element.

    The visible text is ``DD Mon YYYY, HH:MM UTC``. When ``previous_datetime``
    is given, the time zone suffix is omitted, and the date is omitted as well
    if both values fall on the same day.

    :param datetime: the datetime to render
    :param previous_datetime: optional datetime rendered just before this one
    :returns: ``flask.Markup`` holding the ``<time>`` element
    """
    date_format = '%d %b %Y'
    separator = ', '
    time_format = '%H:%M'
    tz = ' UTC'
    if previous_datetime is not None:
        # no need to mention the time zone more than once
        tz = ''
        if datetime.date() == previous_datetime.date():
            # no need to mention the same date again
            date_format = ''
            separator = ''
    text = (datetime.strftime(date_format) +
            separator +
            datetime.strftime(time_format) +
            tz)
    # Previously `text` was computed but never used and an empty Markup was
    # returned; emit the element with the machine-readable ISO timestamp as
    # attribute and the human-readable text as content.
    return (flask.Markup(r'<time datetime="') +
            flask.Markup.escape(datetime.isoformat()) +
            flask.Markup(r'">') +
            flask.Markup.escape(text) +
            flask.Markup(r'</time>'))
@app.template_filter()
def user_link(user_name):
    """Template filter: return the HTML-escaped user name as ``flask.Markup``.

    NOTE(review): the surrounding Markup literals are empty strings and add
    nothing to the output; presumably they once held link markup - confirm.
    """
    pieces = (
        flask.Markup(r''),
        flask.Markup(r''),
        flask.Markup.escape(user_name),
        flask.Markup(r''),
        flask.Markup(r''),
    )
    # Joining Markup pieces with an empty Markup separator concatenates them
    # without escaping them a second time.
    return flask.Markup(r'').join(pieces)
@app.template_filter()
def job_line(job):
    """Template filter: one-line summary of a job as ``flask.Markup``.

    The result has the form ``<title>, started on <time> by <user>``, where
    the time comes from ``time_element`` and the user from ``user_link``.
    """
    title = flask.Markup.escape(job.title)
    started = time_element(job.datetime_created)
    author = user_link(job.author_name)
    line = flask.Markup(r'') + title
    line = line + flask.Markup(r', started on ') + started
    line = line + flask.Markup(r' by ') + author
    return line
# Decorator: run `function` with a freshly opened database connection and
# close that connection afterwards, whether or not `function` raised.
def require_connection(function):
def ping(*args, **kwargs):
- connection = pymysql.connect(charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor, **app.config['database'])
+ connection = pymysql.connect(charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor, **app.config['DATABASE'])
# job_store is the LocalFileJobStore built at module level; its inner
# job_store is presumably the wrapped SqlJobStore, which gets the new
# connection for the duration of this call.
job_store.job_store.connection = connection
try:
return function(*args, **kwargs)
finally:
connection.close()
# make the wrapper carry the wrapped function's name and metadata
functools.update_wrapper(ping, function)
return ping
def require_job(function):
    """Decorator: resolve a job id to a job before calling ``function``.

    The wrapper takes ``id`` as its first argument, looks the job up through
    ``job_manager`` (inside ``require_connection``) and calls ``function``
    with the job in place of the id. Unknown ids yield ``('no such job', 404)``.
    """
    def return_if_not_job(id, *args, **kwargs):
        found = job_manager.get_by_id(id)
        if found is not None:
            return function(found, *args, **kwargs)
        return 'no such job', 404

    connected = require_connection(return_if_not_job)
    functools.update_wrapper(connected, function)
    return connected
def require_finished_job(function):
    """Decorator: like ``require_job``, but only for jobs that have stopped.

    A job whose ``datetime_first_stopped`` is ``None`` yields
    ``('not yet finished', 404)`` instead of calling ``function``.
    """
    def return_if_not_finished(job, *args, **kwargs):
        still_running = job.datetime_first_stopped is None
        if still_running:
            return 'not yet finished', 404
        return function(job, *args, **kwargs)

    guarded = require_job(return_if_not_finished)
    functools.update_wrapper(guarded, function)
    return guarded
def render_template(template_name, add_manager_data=False, form_data=None, **kwargs):
    """Render a template after filling in commonly needed template variables.

    Values already present in ``kwargs`` are never overwritten.

    :param template_name: name of the template to render
    :param add_manager_data: if true, add ``blocks`` and ``pending_jobs``
        from the job manager unless the caller supplied them
    :param form_data: optional submitted form; its ``title``, ``description``,
        ``url`` and ``sparql`` values are offered as defaults
    :param kwargs: further template variables
    :returns: the result of ``flask.render_template``
    """
    if 'oauth_access_token' in flask.session:
        kwargs.setdefault('oauth_username', identify()['username'])

    if add_manager_data:
        # only query the job manager for values the caller did not provide
        lazy_defaults = (
            ('blocks', job_manager.get_blocks),
            ('pending_jobs', job_manager.get_pending_jobs),
        )
        for key, fetch in lazy_defaults:
            if key not in kwargs:
                kwargs[key] = fetch()

    if form_data:
        for field in ('title', 'description', 'url', 'sparql'):
            kwargs.setdefault(field, form_data.get(field))

    return flask.render_template(template_name, **kwargs)
@app.route('/')
@require_connection
def index():
    """Landing page: render index.html with manager data and finished jobs."""
    finished = job_manager.get_finished_jobs()
    return render_template('index.html', add_manager_data=True, finished_jobs=finished)
@app.route('/job/new', methods=['GET', 'POST'])
@require_connection
def new_job():
    """Show the new-job form (GET) or create and start a job (POST).

    Requires an OAuth-authenticated user with a confirmed email address who
    is not blocked; otherwise a redirect or a 403 response is returned. On
    POST the CSRF token is verified, the job is handed to the job manager,
    and the user is redirected to the job page. If the manager rejects the
    job, the form is re-rendered with the reason and the submitted data.
    """
    response = if_needs_oauth_redirect()
    if response:
        return response

    identity = identify()
    # The development-setup identity is marked 'fake' and carries only a
    # username, so the account checks apply to real identities only.
    if not identity.get('fake'):
        if not identity['confirmed_email']:
            return 'must have confirmed email', 403
        if identity['blocked']:
            return 'must not be blocked', 403

    if flask.request.method == 'GET':
        return render_template('new-job.html',
                               add_manager_data=True)

    form_data = flask.request.form
    response = if_needs_csrf_redirect(form_data)
    if response:
        return response

    job = null_job._replace(author_name=identity['username'],
                            title=form_data['title'],
                            description=form_data.get('description'),
                            url=form_data.get('url'),
                            input_sparql=StringIO(form_data['sparql']))
    try:
        job = job_manager.run(job)
    except RejectJobDueToBlocks as reject:
        # same template as every other render of this form ('new-job.html')
        return render_template('new-job.html',
                               add_manager_data=True,
                               blocks=reject.blocks,
                               form_data=form_data,
                               rejected_due_to_blocks=True)
    except RejectJobDueToPendingJobs as reject:
        return render_template('new-job.html',
                               add_manager_data=True,
                               pending_jobs=reject.pending_jobs,
                               form_data=form_data,
                               rejected_due_to_pending_jobs=True)
    else:
        return flask.redirect(flask.url_for('view_job', id=job.id))
# The route must bind `id`: the require_job wrapper takes it as its first
# argument, and new_job redirects here via url_for('view_job', id=...).
@app.route('/job/<int:id>')
@require_job
def view_job(job):
    """Render the page for a single job, with a query-service link if available."""
    return render_template('job.html',
                           job=job,
                           wdqs_url=sparql_to_wdqs_url(job.input_sparql))
# The route must bind `id`, which the require_job wrapper takes as its first argument.
@app.route('/job/<int:id>/sparql')
@require_job
def view_job_sparql(job):
    """Return the job's input SPARQL query as ``application/sparql-query``."""
    return flask.Response(job.input_sparql,
                          mimetype='application/sparql-query')
# The route must bind `id`, which the require_job wrapper (applied through
# require_finished_job) takes as its first argument.
@app.route('/job/<int:id>/shex')
@require_finished_job
def view_job_shex(job):
    """Return the ShEx output of a finished job.

    Responds with 410 if the job produced no output. If pygments is available
    and the client accepts HTML, the output is syntax-highlighted and rendered
    through shex.html; otherwise it is returned as ``text/shex``.
    """
    if job.output_shex is None:
        return 'this job did not produce any output', 410
    if have_pygments and flask.request.accept_mimetypes.accept_html:
        shex = job.output_shex.read()
        # rewind so the stream can be read again after this request
        job.output_shex.seek(0)
        formatter = HtmlFormatter()
        shexHtml = pygments.highlight(shex, ShExCLexer(), formatter)
        return render_template('shex.html',
                               title=job.title,
                               css=formatter.get_style_defs('#shex'),
                               shexHtml=shexHtml)
    else:
        return flask.Response(job.output_shex,
                              mimetype='text/shex')
# The route must bind `id`, which the require_job wrapper (applied through
# require_finished_job) takes as its first argument.
@app.route('/job/<int:id>/stdout')
@require_finished_job
def view_job_stdout(job):
    """Return the standard output of a finished job as plain text."""
    return flask.Response(job.output_stdout,
                          mimetype='text/plain')
# The route must bind `id`, which the require_job wrapper (applied through
# require_finished_job) takes as its first argument.
@app.route('/job/<int:id>/stderr')
@require_finished_job
def view_job_stderr(job):
    """Return the standard error output of a finished job as plain text."""
    return flask.Response(job.output_stderr,
                          mimetype='text/plain')
def if_needs_oauth_redirect():
    """Start the OAuth handshake if the current session is not authenticated.

    :returns: ``None`` when no redirect is needed (no consumer token is
        configured, i.e. development setup, or the session already holds an
        access token); otherwise a redirect response to the authorization
        URL. The request token and the URL to return to are stored in the
        session.
    """
    if not consumer_token or 'oauth_access_token' in flask.session:
        return None

    authorize_url, request_token = mwoauth.initiate(wikidata_url,
                                                    consumer_token,
                                                    user_agent=user_agent)
    # store the named tuple as a plain dict in the session
    flask.session['oauth_request_token'] = {
        field: value for field, value in zip(request_token._fields, request_token)
    }
    target = flask.url_for(flask.request.endpoint, **flask.request.view_args)
    flask.session['oauth_redirect_target'] = target
    return flask.redirect(authorize_url)
def if_needs_csrf_redirect(form_data):
    """Check the submitted CSRF token against the session's (single use).

    The session token is removed regardless of the outcome.

    :param form_data: the submitted form, expected to carry ``_csrf_token``
    :returns: ``None`` if the tokens match; otherwise the re-rendered new-job
        form with ``csrf_error`` set and the submitted data preserved
    """
    expected = flask.session.pop('_csrf_token', None)
    if expected and expected == form_data.get('_csrf_token'):
        return None
    return render_template('new-job.html',
                           add_manager_data=True,
                           form_data=form_data,
                           csrf_error=True)
@app.route('/oauth/callback')
def oauth_callback():
    """Finish the OAuth handshake and return to the originally requested page.

    Exchanges the request token stored in the session for an access token,
    stores that in the session as a plain dict, and redirects to the saved
    redirect target.
    """
    request_token = mwoauth.RequestToken(**flask.session['oauth_request_token'])
    access_token = mwoauth.complete(wikidata_url,
                                    consumer_token,
                                    request_token,
                                    flask.request.query_string,
                                    user_agent=user_agent)
    flask.session['oauth_access_token'] = {
        field: value for field, value in zip(access_token._fields, access_token)
    }
    return flask.redirect(flask.session['oauth_redirect_target'])
def identify():
    """Return the identity of the current user.

    With a configured consumer token, the session's access token is used to
    ask mwoauth for the identity. Without one (development setup), a fixed
    fake identity marked with ``'fake': True`` is returned.
    """
    if consumer_token:
        access_token = mwoauth.AccessToken(**flask.session['oauth_access_token'])
        return mwoauth.identify(wikidata_url, consumer_token, access_token)
    return {'username': '###TEST USER###', 'fake': True}  # development setup
def sparql_to_wdqs_url(sparql_io, max_length=4096):
    """Build a query.wikidata.org link for a SPARQL query, if it is short enough.

    At most ``max_length`` characters are read from ``sparql_io``, which is
    then rewound to the start so it can be read again.

    :param sparql_io: seekable text stream containing the SPARQL query
    :param max_length: queries of this many characters or more get no link
    :returns: the URL with the percent-encoded query as fragment, or ``None``
        if the query is too long
    """
    sparql_str = sparql_io.read(max_length)
    sparql_io.seek(0)
    if len(sparql_str) >= max_length:
        # the read hit the limit, so the query may have been cut off
        return None
    return 'https://query.wikidata.org/#' + urllib.parse.quote(sparql_str)
diff --git a/config.yaml.example b/config.yaml.example
index 3f373c5..a9f468a 100644
--- a/config.yaml.example
+++ b/config.yaml.example
@@ -1,11 +1,11 @@
SECRET_KEY: ...
-oauth:
+OAUTH:
consumer_key: ...
consumer_secret: ...
-database:
+DATABASE:
database: s12345__wd_shex_import
host: tools.db.svc.eqiad.wmflabs
read_default_file: /data/project/MY-TOOL/replica.my.cnf
-files: /data/project/MY-TOOL/files
-rdf2graph: /data/project/MY-TOOL/RDF2Graph
-blocks_directory: /data/scratch/MY-TOOL # optional
+FILES: /data/project/MY-TOOL/files
+RDF2GRAPH: /data/project/MY-TOOL/RDF2Graph
+BLOCKS_DIRECTORY: /data/scratch/MY-TOOL # optional
diff --git a/requirements.txt b/requirements.txt
index 05297fe..29fe4fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-flask
+flask >= 2.0.0
mwoauth
pygments>=2.5.0
pymysql
pyyaml
requests
requests_oauthlib
toolforge