Skip to content

Commit

Permalink
Merge pull request #279 from bacpop/i273
Browse files Browse the repository at this point in the history
Allow overwrite if forced
  • Loading branch information
johnlees authored Oct 13, 2023
2 parents eaad3db + 54ffc03 commit 14bcd2b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/azure_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Get current date
id: date
run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}"
- name: Install Conda environment from environment.yml
uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: '1.4.6-0'
environment-file: environment.yml
cache-environment: true
# persist on the same day.
cache-environment-key: environment-${{ steps.date.outputs.date }}
cache-downloads-key: downloads-${{ steps.date.outputs.date }}
- name: Install and run_test.py
shell: bash -l {0}
run: |
Expand Down
11 changes: 6 additions & 5 deletions PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def assign_query(dbFuncs,
constructDatabase = dbFuncs['constructDatabase']
readDBParams = dbFuncs['readDBParams']

if ref_db == output:
sys.stderr.write("--output and --ref-db must be different to "
if ref_db == output and overwrite == False:
sys.stderr.write("--output and --db must be different to "
"prevent overwrite.\n")
sys.exit(1)

Expand Down Expand Up @@ -386,8 +386,8 @@ def assign_query_hdf5(dbFuncs,
readDBParams = dbFuncs['readDBParams']
getSeqsInDb = dbFuncs['getSeqsInDb']

if ref_db == output:
sys.stderr.write("--output and --ref-db must be different to "
if ref_db == output and overwrite == False:
sys.stderr.write("--output and --db must be different to "
"prevent overwrite.\n")
sys.exit(1)
if (update_db and not distances):
Expand Down Expand Up @@ -509,8 +509,9 @@ def assign_query_hdf5(dbFuncs,

n_vertices = len(get_vertex_list(genomeNetwork, use_gpu = gpu_graph))
if n_vertices != len(rNames):
sys.stderr.write(f"There are {n_vertices} vertices in the network but {len(rNames)} reference names supplied; " + \
sys.stderr.write(f"ERROR: There are {n_vertices} vertices in the network but {len(rNames)} reference names supplied; " + \
"please check the '--model-dir' variable is pointing to the correct directory\n")
sys.exit(1)

if model.type == 'lineage':
# Assign lineages by calculating query-query information
Expand Down
15 changes: 10 additions & 5 deletions PopPUNK/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,16 @@ def network_to_edges(prev_G_fn, rlist, adding_qq_dists = False,
source_ids = old_source_ids
target_ids = old_target_ids
else:
# Update IDs to new versions
old_id_indices = [rlist.index(x) for x in old_ids]
# translate to indices
source_ids = [old_id_indices[x] for x in old_source_ids]
target_ids = [old_id_indices[x] for x in old_target_ids]
try:
# Update IDs to new versions
old_id_indices = [rlist.index(x) for x in old_ids]
# translate to indices
source_ids = [old_id_indices[x] for x in old_source_ids]
target_ids = [old_id_indices[x] for x in old_target_ids]
except ValueError:
sys.stderr.write(f"Network size mismatch. Previous network nodes: {max(old_id_indices)}."
f"New network nodes: {max(old_source_ids.a)}/{max(old_target_ids.a)}\n")
sys.exit(1)

# return values
if weights:
Expand Down

0 comments on commit 14bcd2b

Please sign in to comment.