Skip to content

Commit

Permalink
Allow upload to different DBs, refs #28
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Jan 30, 2024
1 parent cc2c6ab commit f8fdde8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
14 changes: 12 additions & 2 deletions datasette_upload_csvs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,31 @@ async def upload_csvs(scope, receive, datasette, request):
]
if not dbs:
raise Forbidden("No mutable databases available")
db = dbs[0]

default_db = dbs[0]

# We need the ds_request to pass to render_template for CSRF tokens
ds_request = request

# We use the Starlette request object to handle file uploads
starlette_request = Request(scope, receive)
if starlette_request.method != "POST":
selected_db = ds_request.args.get("database")
databases = []
# If there are multiple databases let them choose
if len(dbs) > 1:
databases = [
{"name": db.name, "selected": db.name == selected_db} for db in dbs
]
return Response.html(
await datasette.render_template(
"upload_csv.html", {"database_name": db.name}, request=ds_request
"upload_csv.html", {"databases": databases}, request=ds_request
)
)

formdata = await starlette_request.form()
database_name = formdata.get("database") or default_db.name
db = datasette.get_database(database_name)
csv = formdata["csv"]
# csv.file is a SpooledTemporaryFile. csv.filename is the filename
table_name = formdata.get("table")
Expand Down
19 changes: 17 additions & 2 deletions datasette_upload_csvs/templates/upload_csv.html
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,19 @@

{% block content %}
<h1>Upload a CSV</h1>
<p>A table will be created in database "<strong><a href="{{ urls.database(database_name) }}">{{ database_name }}</a></strong>".</p>

<form action="/-/upload-csvs" id="uploadForm" method="post" enctype="multipart/form-data">

{% if databases %}
<p><label>Database &nbsp; &nbsp;
<select id="id_database" name="database">
{% for database in databases %}
<option{% if database.selected %} selected{% endif %}>{{ database.name }}</option>
{% endfor %}
</select></label>
</p>
{% endif %}

<div id="file-drop">
<input type="hidden" name="csrftoken" value="{{ csrftoken() }}">
<input type="file" name="csv" id="csvUpload">
Expand All @@ -57,7 +68,7 @@ <h1>Upload a CSV</h1>
<p style="margin-bottom: -0.8em;font-size: 0.8em; display: none;" id="progress-label">Uploading...</p>
<progress class="progress" value="0" max="100">Uploading...</progress>
<p style="margin-top: 1em">
<label for="id_table_name">Table name</label>
<label for="id_table_name">Table name</label>&nbsp; &nbsp;
<input required id="id_table_name" type="text" name="table_name">
</p>
<p><input type="submit" value="Upload file" class="button"></p>
Expand All @@ -72,6 +83,7 @@ <h1>Upload a CSV</h1>
var progressLabel = document.getElementById("progress-label");
var label = dropArea.getElementsByTagName("label")[0];
var tableName = document.getElementById("id_table_name");
var databaseName = document.getElementById("id_database");

// State that holds the most-recent uploaded File, from a FileList
let currentFile = null;
Expand Down Expand Up @@ -199,6 +211,9 @@ <h1>Upload a CSV</h1>
formData.append("csrftoken", "{{ csrftoken() }}");
formData.append("csv", currentFile);
formData.append("table", tableName.value);
if (databaseName) {
formData.append("database", databaseName.value);
}
xhr.send(formData);
});
</script>
Expand Down
48 changes: 30 additions & 18 deletions tests/test_datasette_upload_csvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from asgi_lifespan import LifespanManager
import json
from unittest.mock import ANY
import pathlib
import pytest
import httpx
import sqlite_utils
Expand Down Expand Up @@ -113,23 +114,28 @@ async def test_menu(tmpdir, auth, has_database):
),
)
@pytest.mark.parametrize("use_xhr", (True, False))
@pytest.mark.parametrize("database", ("data", "data2"))
async def test_upload(
tmpdir, filename, content, expected_table, expected_rows, use_xhr
tmpdir, filename, content, expected_table, expected_rows, use_xhr, database
):
expected_url = "/data/{}".format(tilde_encode(expected_table))
expected_url = "/{}/{}".format(database, tilde_encode(expected_table))
path = str(tmpdir / "data.db")
db = sqlite_utils.Database(path)
db.vacuum()
db.enable_wal()
db["already_exists"].insert({"id": 1})
binary_content = content
# Trick to avoid a 12MB string being part of the pytest rendered test name:
if content == "LATIN1_AFTER_FIRST_2KB":
binary_content = LATIN1_AFTER_FIRST_2KB

db["hello"].insert({"hello": "world"})

datasette = Datasette([path])
path2 = str(tmpdir / "data2.db")
dbs_by_name = {}
for p in (path, path2):
db = sqlite_utils.Database(p)
dbs_by_name[pathlib.Path(p).stem] = db
db.vacuum()
db.enable_wal()
db["already_exists"].insert({"id": 1})
binary_content = content
# Trick to avoid a 12MB string being part of the pytest rendered test name:
if content == "LATIN1_AFTER_FIRST_2KB":
binary_content = LATIN1_AFTER_FIRST_2KB

db["hello"].insert({"hello": "world"})

datasette = Datasette([path, path2])

cookies = {"ds_actor": datasette.sign({"a": {"id": "root"}}, "actor")}

Expand All @@ -141,6 +147,7 @@ async def test_upload(
'<form action="/-/upload-csvs" id="uploadForm" method="post"'
in response.text
)
assert '<select id="id_database"' in response.text
csrftoken = response.cookies["ds_csrftoken"]
cookies["ds_csrftoken"] = csrftoken

Expand All @@ -153,7 +160,11 @@ async def test_upload(
else ""
),
cookies=cookies,
data={"csrftoken": csrftoken, "xhr": "1" if use_xhr else ""},
data={
"csrftoken": csrftoken,
"xhr": "1" if use_xhr else "",
"database": database,
},
files=files,
)
if use_xhr:
Expand All @@ -167,7 +178,7 @@ async def test_upload(
iterations = 0
while True:
response = await client.get(
"http://localhost/data/_csv_progress_.json?_shape=array"
f"http://localhost/{database}/_csv_progress_.json?_shape=array"
)
rows = json.loads(response.content)
assert 1 == len(rows)
Expand All @@ -178,10 +189,11 @@ async def test_upload(
break
iterations += 1
assert iterations < fail_after, "Took too long: {}".format(row)
await asyncio.sleep(0.5)
await asyncio.sleep(0.2)

# Give time for last operation to complete:
await asyncio.sleep(0.5)
db = dbs_by_name[database]
await asyncio.sleep(0.2)
rows = list(db[expected_table].rows)
assert rows == expected_rows

Expand Down

0 comments on commit f8fdde8

Please sign in to comment.