def get_ct_embeddings(mytoken: str, dicom_urls: List[str]) -> List[Any]:
    """Calls the API to collect the embeddings for a batch of CT volumes.

    Args:
      mytoken: The token to access the DICOM volumes and API. The same token
        is attached to every instance in the request.
      dicom_urls: The list of Series-level DICOM URLs to the CT volumes.

    Returns:
      A list with one entry per prediction returned by the service. Each entry
      is either a float32 `np.ndarray` of shape (1408,) holding the embedding,
      or the service's `error_response` value when that prediction failed.
      Differences in Vertex end-point configurations may change the return
      type. The caller is responsible for interpreting each entry (e.g.
      telling embeddings from errors) and extracting the requisite data.
      An empty `dicom_urls` yields an empty list without calling the service.

    Raises:
      AssertionError: If the number of predictions does not match the number
        of requested URLs, if the first prediction does not have exactly 3
        fields, or if an embedding does not have shape (1408,).
    """
    # Nothing to embed: skip the network round-trip and avoid indexing into an
    # empty predictions list below.
    if not dicom_urls:
        return []

    api_client = aiplatform.gapic.PredictionServiceClient(
        client_options=ClientOptions(api_endpoint=_API_ENDPOINT)
    )
    endpoint = api_client.endpoint_path(
        project='hai-cd3-foundations', location='us-central1', endpoint=300
    )

    # Create one instance per volume - you can send up to 5 in a request.
    instances = [
        {"dicom_path": dicom_url, "bearer_token": mytoken}
        for dicom_url in dicom_urls
    ]

    # Send the whole batch (previously this referenced an undefined
    # `instance` variable left over from the single-URL version).
    response = api_client.predict(
        endpoint=endpoint, instances=instances, retry=retry_policy, timeout=600
    )
    assert len(response.predictions) == len(dicom_urls)
    assert len(response.predictions[0]) == 3
    # You can get the model version for this prediction
    # response.predictions[0]['model_version']

    # Collect either the embedding or the error for every prediction so that a
    # single failed volume does not discard the rest of the batch.
    responses = []
    for a_prediction in response.predictions:
        if a_prediction['error_response'] is not None:
            responses.append(a_prediction['error_response'])
        else:
            embeddings = np.array(
                a_prediction['embedding_result']['embedding'],
                dtype=np.float32,
            )
            assert embeddings.shape == (1408,), 'Unexpected embeddings shape.'
            responses.append(embeddings)

    # Return the full per-volume list, not just the last parsed embedding.
    return responses