From 967684304b9fb0d977289bacadedb3ab407214cd Mon Sep 17 00:00:00 2001 From: Simran Arora Date: Thu, 8 Feb 2024 20:58:30 +0000 Subject: [PATCH] modify general ar synthetic --- notebooks/test_data.ipynb | 184 +++++++++++++++++++++++++++++ zoology/data/associative_recall.py | 2 +- 2 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 notebooks/test_data.ipynb mode change 100644 => 100755 zoology/data/associative_recall.py diff --git a/notebooks/test_data.ipynb b/notebooks/test_data.ipynb new file mode 100644 index 0000000..efd10f9 --- /dev/null +++ b/notebooks/test_data.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "# add autoreload\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/code/zoology/zoology/data\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(583) tensor(-100)\n", + "tensor(5999) tensor(-100)\n", + "tensor(1962) tensor(-100)\n", + "tensor(5773) tensor(-100)\n", + "tensor(1958) tensor(-100)\n", + "tensor(6954) tensor(-100)\n", + "tensor(3194) tensor(-100)\n", + "tensor(6629) tensor(-100)\n", + "tensor(1958) tensor(6954)\n", + "tensor(780) tensor(-100)\n", + "tensor(6806) tensor(-100)\n", + "tensor(143) tensor(-100)\n", + "tensor(1729) tensor(-100)\n", + "tensor(6368) tensor(-100)\n", + "tensor(6014) tensor(-100)\n", + "tensor(4482) tensor(-100)\n", + "tensor(2079) tensor(-100)\n", + "tensor(305) tensor(-100)\n", + "tensor(3111) tensor(-100)\n", + "tensor(253) tensor(-100)\n", + "tensor(4161) tensor(-100)\n", + "tensor(516) tensor(-100)\n", + "tensor(1939) tensor(-100)\n", + "tensor(4073) tensor(-100)\n", + "tensor(1458) tensor(-100)\n", + "tensor(5804) tensor(-100)\n", + "tensor(3292) tensor(-100)\n", + "tensor(1369) tensor(-100)\n", + "tensor(3024) tensor(-100)\n", + "tensor(7556) tensor(-100)\n" + ] + } + ], + "source": [ + "from associative_recall import associative_recall\n", + "data = associative_recall()\n", + "x= data.train_inputs[0]\n", + "y= data.train_labels[0]\n", + "\n", + "for i, (_x, _y) in enumerate(zip(x, y)):\n", + " if i == 30:\n", + " break\n", + " print(_x, _y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(583) tensor(-100)\n", + "tensor(4898) tensor(-100)\n", + "tensor(1962) tensor(-100)\n", + "tensor(4671) tensor(-100)\n", + "tensor(1958) tensor(-100)\n", + "tensor(7947) tensor(-100)\n", + "tensor(3194) tensor(-100)\n", + "tensor(6084) tensor(-100)\n", + "tensor(1962) tensor(4671)\n", + "tensor(5096) tensor(-100)\n", + "tensor(5725) tensor(-100)\n", + "tensor(6249) tensor(-100)\n", + "tensor(6204) tensor(-100)\n", + "tensor(2841) tensor(-100)\n", + "tensor(3194) tensor(6084)\n", + "tensor(401) tensor(-100)\n", + "tensor(6296) tensor(-100)\n", + "tensor(881) tensor(-100)\n", + "tensor(7629) tensor(-100)\n", + "tensor(2311) tensor(-100)\n", + "tensor(3502) tensor(-100)\n", + "tensor(7825) tensor(-100)\n", + "tensor(3207) tensor(-100)\n", + "tensor(910) tensor(-100)\n", + "tensor(8025) tensor(-100)\n", + "tensor(1454) tensor(-100)\n", + "tensor(1236) tensor(-100)\n", + "tensor(3014) tensor(-100)\n", + "tensor(5794) tensor(-100)\n", + "tensor(6338) tensor(-100)\n" + ] + } + ], + "source": [ + "from associative_recall import multiquery_ar\n", + "data = multiquery_ar()\n", + "x= data.train_inputs[0]\n", + "y= data.train_labels[0]\n", + "\n", + "for i, (_x, _y) in enumerate(zip(x, y)):\n", + " if i == 30:\n", + " break\n", + " print(_x, _y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/zoology/data/associative_recall.py b/zoology/data/associative_recall.py old mode 100644 new mode 100755 index 32cdd3e..66361bb --- a/zoology/data/associative_recall.py +++ b/zoology/data/associative_recall.py @@ -154,7 +154,7 @@ def _ar( inputs[rows, query_pos] = queries targets[rows, query_pos] = labels - inputs, targets = torch.tensor(inputs[:, :-1]), torch.tensor(targets[:, 1:]) + inputs, targets = torch.tensor(inputs[:, :-1]), torch.tensor(targets[:, :-1]) # replace all the 0 with random values if random_non_queries: