Skip to content

Commit

Permalink
Search parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Hampus Linander committed Oct 1, 2024
1 parent 487b0c0 commit 14410da
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 39 deletions.
197 changes: 173 additions & 24 deletions experiments/weather/ring_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,43 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"id": "7384006f-a3fc-4501-820d-b3852f4e3239",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[db] Connection to alvis2:5431\n"
"[Compute environment] Could not load env.py: \n",
"Traceback (most recent call last):\n",
" File \"/Users/hampus/projects/equivariant-posteriors/lib/compute_env.py\", line 17, in env\n",
" import env\n",
"ModuleNotFoundError: No module named 'env'\n",
"\n",
"[Compute environment] Using defaults\n",
"[Compute environment] paths: \n",
"[Paths] checkpoints: checkpoints (/Users/hampus/projects/equivariant-posteriors/experiments/weather/checkpoints)\n",
"[Paths] locks: locks (/Users/hampus/projects/equivariant-posteriors/experiments/weather/locks)\n",
"[Paths] distributed_requests: distributed_requests (/Users/hampus/projects/equivariant-posteriors/experiments/weather/distributed_requests)\n",
"[Paths] artifacts: artifacts (/Users/hampus/projects/equivariant-posteriors/experiments/weather/artifacts)\n",
"[Paths] datasets: datasets (/Users/hampus/projects/equivariant-posteriors/experiments/weather/datasets)\n",
"[Compute environment] postgres_host: localhost\n",
"[Compute environment] postgres_port: 5432\n",
"[Compute environment] postgres_password: postgres\n"
]
},
{
"ename": "OperationalError",
"evalue": "connection failed: could not receive data from server: Connection refused",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mOperationalError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 52\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m train_run\n\u001b[1;32m 51\u001b[0m config \u001b[38;5;241m=\u001b[39m create_config(\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m---> 52\u001b[0m \u001b[43madd_train_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 53\u001b[0m result_path \u001b[38;5;241m=\u001b[39m prepare_results(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mring_windows\u001b[39m\u001b[38;5;124m\"\u001b[39m, config)\n\u001b[1;32m 54\u001b[0m setup_psql()\n",
"File \u001b[0;32m~/projects/equivariant-posteriors/lib/render_psql.py:463\u001b[0m, in \u001b[0;36madd_train_run\u001b[0;34m(train_run)\u001b[0m\n\u001b[1;32m 462\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21madd_train_run\u001b[39m(train_run):\n\u001b[0;32m--> 463\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mpsycopg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnect\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 464\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdbname=equiv user=postgres password=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43menv\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpostgres_password\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 465\u001b[0m \u001b[43m \u001b[49m\u001b[43mhost\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpostgres_host\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[43m \u001b[49m\u001b[43mport\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpostgres_port\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 467\u001b[0m \u001b[43m \u001b[49m\u001b[43mautocommit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 468\u001b[0m \u001b[43m \u001b[49m\u001b[43mprepare_threshold\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 469\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m conn:\n\u001b[1;32m 470\u001b[0m insert_or_update_train_run(conn, train_run)\n\u001b[1;32m 471\u001b[0m conn\u001b[38;5;241m.\u001b[39mcommit()\n",
"File \u001b[0;32m/nix/store/ayax3wxh2nyji9rbnm2gnprnyrkqrpks-python3-3.11.9-env/lib/python3.11/site-packages/psycopg/connection.py:749\u001b[0m, in \u001b[0;36mConnection.connect\u001b[0;34m(cls, conninfo, autocommit, prepare_threshold, row_factory, cursor_factory, context, **kwargs)\u001b[0m\n\u001b[1;32m 747\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m rv:\n\u001b[1;32m 748\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m last_ex\n\u001b[0;32m--> 749\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m last_ex\u001b[38;5;241m.\u001b[39mwith_traceback(\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 751\u001b[0m rv\u001b[38;5;241m.\u001b[39m_autocommit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mbool\u001b[39m(autocommit)\n\u001b[1;32m 752\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m row_factory:\n",
"\u001b[0;31mOperationalError\u001b[0m: connection failed: could not receive data from server: Connection refused"
]
}
],
Expand Down Expand Up @@ -74,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 256,
"execution_count": 2,
"id": "05d3ed22-3bb8-41bf-999e-855dccbbd975",
"metadata": {},
"outputs": [
Expand All @@ -84,7 +112,7 @@
"192"
]
},
"execution_count": 256,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -101,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": 257,
"execution_count": 3,
"id": "284732c2-f52d-4925-958a-d3b45bab6edb",
"metadata": {},
"outputs": [],
Expand All @@ -111,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": 258,
"execution_count": 4,
"id": "7a6a3f80-e281-4a00-b0c7-bad6f7bdf3b7",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -147,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": 259,
"execution_count": 5,
"id": "cc9343e8-c1d3-4071-96e1-0573e9ae202e",
"metadata": {},
"outputs": [
Expand All @@ -160,7 +188,7 @@
" [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]]"
]
},
"execution_count": 259,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -171,7 +199,7 @@
},
{
"cell_type": "code",
"execution_count": 261,
"execution_count": 6,
"id": "6acdc006-7178-464a-90ca-f551a2db1b13",
"metadata": {},
"outputs": [
Expand All @@ -181,7 +209,7 @@
"7"
]
},
"execution_count": 261,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -192,17 +220,35 @@
},
{
"cell_type": "code",
"execution_count": 246,
"execution_count": 7,
"id": "1bb79ee2-3213-4103-984e-ee4eae0cf2eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[11, 10, 9, 8]]"
"[[167,\n",
" 166,\n",
" 165,\n",
" 164,\n",
" 163,\n",
" 162,\n",
" 161,\n",
" 160,\n",
" 159,\n",
" 158,\n",
" 157,\n",
" 156,\n",
" 155,\n",
" 154,\n",
" 153,\n",
" 152],\n",
" [179, 178, 177, 176, 175, 174, 173, 172, 171, 170, 169, 168],\n",
" [187, 186, 185, 184, 183, 182, 181, 180],\n",
" [191, 190, 189, 188]]"
]
},
"execution_count": 246,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -213,10 +259,41 @@
},
{
"cell_type": "code",
"execution_count": 247,
"execution_count": 14,
"id": "e179ea45-b0c9-498f-a5e9-012614cce159",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(4,)\n",
"(8,)\n",
"(12,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(8,)\n",
"(12,)\n",
"(8,)\n",
"(4,)\n"
]
}
],
"source": [
"all_windows = north_idxs + north_eq_idxs + south_eq_idxs + south_idxs\n",
"colors = np.arange(len(all_windows))\n",
Expand All @@ -227,22 +304,37 @@
" nest_idxs = chp.ring2nest(NSIDE, window)\n",
" for sub_idx in range(n_sub_windows):\n",
" sub_idxs = nest_idxs[sub_idx::n_sub_windows]\n",
" #print(sub_idxs.shape)\n",
" hp[sub_idxs] = float(colors[(sub_idx + idx) % len(all_windows)])#float(2*idx % len(all_windows)) + 1"
]
},
{
"cell_type": "code",
"execution_count": 248,
"execution_count": 9,
"id": "b3302adb-1ec1-4eb0-a378-aca00380c7a3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([2., 2., 2., 2., 1., 1., 1., 1., 0., 0., 0., 0.], dtype=float32)"
"array([ 4., 7., 4., 2., 7., 2., 10., 13., 7., 2., 10., 13., 13.,\n",
" 8., 8., 0., 4., 7., 4., 2., 7., 2., 10., 13., 7., 2.,\n",
" 10., 13., 13., 8., 8., 0., 4., 7., 4., 2., 7., 2., 10.,\n",
" 13., 7., 2., 10., 13., 13., 8., 8., 0., 4., 7., 4., 2.,\n",
" 7., 2., 10., 13., 7., 2., 10., 13., 13., 8., 8., 0., 5.,\n",
" 1., 5., 14., 1., 14., 11., 11., 1., 14., 11., 11., 4., 7.,\n",
" 4., 2., 5., 1., 5., 14., 1., 14., 11., 11., 1., 14., 11.,\n",
" 11., 4., 7., 4., 2., 5., 1., 5., 14., 1., 14., 11., 11.,\n",
" 1., 14., 11., 11., 4., 7., 4., 2., 5., 1., 5., 14., 1.,\n",
" 14., 11., 11., 1., 14., 11., 11., 4., 7., 4., 2., 6., 9.,\n",
" 9., 3., 3., 12., 3., 12., 3., 12., 3., 12., 5., 1., 5.,\n",
" 14., 6., 9., 9., 3., 3., 12., 3., 12., 3., 12., 3., 12.,\n",
" 5., 1., 5., 14., 6., 9., 9., 3., 3., 12., 3., 12., 3.,\n",
" 12., 3., 12., 5., 1., 5., 14., 6., 9., 9., 3., 3., 12.,\n",
" 3., 12., 3., 12., 3., 12., 5., 1., 5., 14.], dtype=float32)"
]
},
"execution_count": 248,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -299,7 +391,7 @@
},
{
"cell_type": "code",
"execution_count": 188,
"execution_count": 10,
"id": "9c0a1257-ae87-485f-a804-db79ce5dcbf3",
"metadata": {},
"outputs": [
Expand All @@ -309,7 +401,7 @@
"16.0"
]
},
"execution_count": 188,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -321,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count": 189,
"execution_count": 12,
"id": "a926c83e-2531-48d4-b163-d4775e486c99",
"metadata": {},
"outputs": [
Expand All @@ -331,7 +423,7 @@
"192"
]
},
"execution_count": 189,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -340,10 +432,67 @@
"sum([4, 8, 12, 16, 16, 16, 16, 16, 16, 16, 16, 16, 12, 8, 4])"
]
},
{
"cell_type": "markdown",
"id": "239b40e0-28b5-4051-84bf-776e608dc3a5",
"metadata": {},
"source": [
"# With depth"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "2980c220-0488-44db-b97d-c6b0f8aa0554",
"metadata": {},
"outputs": [],
"source": [
"def get_isolatitude_windows_hp(nside):\n",
" polar_idx = list(range(0, nside))\n",
" current_idx = 0\n",
" north_idxs = []\n",
" north_eq_idxs = []\n",
" south_eq_idxs = []\n",
" south_idxs = []\n",
" for window_idx in polar_idx:\n",
" north_idxs.append([ current_idx + i for i in range(4 * (window_idx + 1))])\n",
" current_idx += 4 * (window_idx + 1)\n",
" \n",
" for window_idx in range(nside):\n",
" north_eq_idxs.append([current_idx + i for i in range(4*nside)])\n",
" current_idx += 4*nside\n",
" \n",
" for window_idx in range(nside - 1):\n",
" south_eq_idxs.append([current_idx + i for i in range(4*nside)])\n",
" current_idx += 4*nside\n",
" \n",
" # nside 2, 0 -> 0\n",
" \n",
" # nside 3, 0 -> 1\n",
" # nside 3, 1 -> 0\n",
" for window in reversed(north_idxs):\n",
" south_idxs.append([n_pixels - 1 - idx for idx in window])\n",
"\n",
" return north_idxs + north_eq_idxs + south_eq_idxs + south_idxs"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "47e683f8-4b1b-434e-b83d-2e04e7725178",
"metadata": {},
"outputs": [],
"source": [
"test = torch.zeros((1, 13\n",
"hp_windows = get_isolatitude_windows_hp(4)\n",
"for hp_window in hp_windows:\n",
" for depth in zip("
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0626ad70-a829-4816-a044-ad44a8425609",
"id": "44741a8a-4979-409b-abb7-06f203e085db",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -365,7 +514,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
"version": "3.11.9"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 14410da

Please sign in to comment.