Skip to content

Commit

Permalink
modify general ar synthetic
Browse files Browse the repository at this point in the history
  • Loading branch information
simran-arora committed Feb 8, 2024
1 parent 240196e commit 9676843
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 1 deletion.
184 changes: 184 additions & 0 deletions notebooks/test_data.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"# add autoreload\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"/code/zoology/zoology/data\")\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(583) tensor(-100)\n",
"tensor(5999) tensor(-100)\n",
"tensor(1962) tensor(-100)\n",
"tensor(5773) tensor(-100)\n",
"tensor(1958) tensor(-100)\n",
"tensor(6954) tensor(-100)\n",
"tensor(3194) tensor(-100)\n",
"tensor(6629) tensor(-100)\n",
"tensor(1958) tensor(6954)\n",
"tensor(780) tensor(-100)\n",
"tensor(6806) tensor(-100)\n",
"tensor(143) tensor(-100)\n",
"tensor(1729) tensor(-100)\n",
"tensor(6368) tensor(-100)\n",
"tensor(6014) tensor(-100)\n",
"tensor(4482) tensor(-100)\n",
"tensor(2079) tensor(-100)\n",
"tensor(305) tensor(-100)\n",
"tensor(3111) tensor(-100)\n",
"tensor(253) tensor(-100)\n",
"tensor(4161) tensor(-100)\n",
"tensor(516) tensor(-100)\n",
"tensor(1939) tensor(-100)\n",
"tensor(4073) tensor(-100)\n",
"tensor(1458) tensor(-100)\n",
"tensor(5804) tensor(-100)\n",
"tensor(3292) tensor(-100)\n",
"tensor(1369) tensor(-100)\n",
"tensor(3024) tensor(-100)\n",
"tensor(7556) tensor(-100)\n"
]
}
],
"source": [
"from associative_recall import associative_recall\n",
"data = associative_recall()\n",
"x= data.train_inputs[0]\n",
"y= data.train_labels[0]\n",
"\n",
"for i, (_x, _y) in enumerate(zip(x, y)):\n",
" if i == 30:\n",
" break\n",
" print(_x, _y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(583) tensor(-100)\n",
"tensor(4898) tensor(-100)\n",
"tensor(1962) tensor(-100)\n",
"tensor(4671) tensor(-100)\n",
"tensor(1958) tensor(-100)\n",
"tensor(7947) tensor(-100)\n",
"tensor(3194) tensor(-100)\n",
"tensor(6084) tensor(-100)\n",
"tensor(1962) tensor(4671)\n",
"tensor(5096) tensor(-100)\n",
"tensor(5725) tensor(-100)\n",
"tensor(6249) tensor(-100)\n",
"tensor(6204) tensor(-100)\n",
"tensor(2841) tensor(-100)\n",
"tensor(3194) tensor(6084)\n",
"tensor(401) tensor(-100)\n",
"tensor(6296) tensor(-100)\n",
"tensor(881) tensor(-100)\n",
"tensor(7629) tensor(-100)\n",
"tensor(2311) tensor(-100)\n",
"tensor(3502) tensor(-100)\n",
"tensor(7825) tensor(-100)\n",
"tensor(3207) tensor(-100)\n",
"tensor(910) tensor(-100)\n",
"tensor(8025) tensor(-100)\n",
"tensor(1454) tensor(-100)\n",
"tensor(1236) tensor(-100)\n",
"tensor(3014) tensor(-100)\n",
"tensor(5794) tensor(-100)\n",
"tensor(6338) tensor(-100)\n"
]
}
],
"source": [
"from associative_recall import multiquery_ar\n",
"data = multiquery_ar()\n",
"x= data.train_inputs[0]\n",
"y= data.train_labels[0]\n",
"\n",
"for i, (_x, _y) in enumerate(zip(x, y)):\n",
" if i == 30:\n",
" break\n",
" print(_x, _y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion zoology/data/associative_recall.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _ar(
inputs[rows, query_pos] = queries
targets[rows, query_pos] = labels

inputs, targets = torch.tensor(inputs[:, :-1]), torch.tensor(targets[:, 1:])
inputs, targets = torch.tensor(inputs[:, :-1]), torch.tensor(targets[:, :-1])

# replace all the 0 with random values
if random_non_queries:
Expand Down

0 comments on commit 9676843

Please sign in to comment.