Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689942556
  • Loading branch information
Google Research authored and rbpilgrim committed Oct 25, 2024
1 parent b11039b commit 9c1bb40
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions ct-foundation/CT_Foundation_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@
"It's important to use evaluation datasets\n",
"that reflect the expected distribution of images and patients you wish to use any downstream models on.\n",
"\n",
"This means that the best way to determine if this API is right for you is to try it with data that would be used for the downstream task you're interested in."
"This means that the best way to determine if this API is right for you is to try it with data that would be used for the downstream task you're interested in.\n",
"\n",
"**Note**: If you want to jump to training a model with embeddings, you can\n",
"scroll down to [Train a model with the embeddings from NLST](#train-nlst)"
]
},
{
Expand Down Expand Up @@ -210,7 +213,6 @@
"from requests_toolbelt.multipart import decoder\n",
"from google.cloud import storage\n",
"import tensorflow as tf\n",
"#import tensorflow_models as tfm\n",
"import matplotlib.pyplot as plt"
]
},
Expand Down Expand Up @@ -632,7 +634,6 @@
},
"outputs": [],
"source": [
"\n",
"# @title Create a Session via a token.\n",
"creds = credentials.Credentials(TOKEN)\n",
"session = create_authorized_session(creds)"
Expand Down Expand Up @@ -783,7 +784,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "qzal615K4rDv"
},
"outputs": [],
Expand Down Expand Up @@ -819,48 +819,50 @@
"retry_policy = Retry(predicate=_is_retryable)\n",
"\n",
"\n",
"def get_ct_embeddings(mytoken: str, dicom_url: str) -\u003e np.ndarray:\n",
"def get_ct_embeddings(mytoken: str, dicom_urls: List[str]) -\u003e List[Any]:\n",
" \"\"\"Calls the API to collect the embeddings from a given volume.\n",
"\n",
" Args:\n",
" mytoken: The token to access the DICOM volume and API.\n",
" dicom_url: The Series-level DICOM URL to the CT volume.\n",
" dicom_urls: The list of Series-level DICOM URL to the CT volumes.\n",
"\n",
" Returns:\n",
" The embeddings generated by the service. Differences in Vertex\n",
" end-point configurations may change the return type. The caller is\n",
" The list of embeddings or errors generated by the service. Differences in\n",
" Vertex end-point configurations may change the return type. The caller is\n",
" responsible for interpreting this value and extracting the requisite\n",
" data.\n",
"\n",
" Raises:\n",
" ValueError: If the API call fails.\n",
" \"\"\"\n",
" api_client = aiplatform.gapic.PredictionServiceClient(\n",
" client_options=ClientOptions(api_endpoint=_API_ENDPOINT)\n",
" )\n",
" endpoint = api_client.endpoint_path(\n",
" project='hai-cd3-use-case-dev', location='use-west1', endpoint=300\n",
" project='hai-cd3-foundations', location='us-central1', endpoint=300\n",
" )\n",
"\n",
" # Create a single instance to send - you can send up to 5.\n",
" instance ={\"dicom_path\": dicom_url,\n",
" \"bearer_token\": mytoken}\n",
" instances = []\n",
" for dicom_url in dicom_urls:\n",
" instances.append({\"dicom_path\": dicom_url, \"bearer_token\": mytoken})\n",
"\n",
" response = api_client.predict(\n",
" endpoint=endpoint, instances=[instance], retry=retry_policy, timeout=60\n",
" endpoint=endpoint, instances=[instance], retry=retry_policy, timeout=600\n",
" )\n",
" assert len(response.predictions) == 1\n",
" assert len(response.predictions) == len(dicom_urls)\n",
" assert len(response.predictions[0]) == 3\n",
" # You can get the model version for this prediction\n",
" # response.predictions[0]['model_version']\n",
"\n",
" # Check for no errors\n",
" if response.predictions[0]['error_response'] is not None:\n",
" raise ValueError(response.predictions[0]['error_response'])\n",
" embeddings = np.array(\n",
" response.predictions[0]['embedding_result']['embedding'],\n",
" dtype=np.float32)\n",
" assert embeddings.shape == (1408,), 'Unexpected embeddings shape recieved.'\n",
" responses = []\n",
" for a_prediction in response.predictions:\n",
" if a_prediction['error_response'] is not None:\n",
" responses.append(a_prediction['error_response'])\n",
" else:\n",
" embeddings = np.array(a_prediction['embedding_result']['embedding'],\n",
" dtype=np.float32)\n",
" assert embeddings.shape == (1408,), 'Unexpected embeddings shape.'\n",
" responses.append(embeddings)\n",
"\n",
" return embeddings"
]
Expand Down Expand Up @@ -889,6 +891,7 @@
"id": "8JQRxrXiUUgv"
},
"source": [
"\u003ca name=\"train-nlst\"\u003e\u003c/a\u003e\n",
"# Train a model with the embeddings from NLST\n",
"\n",
"Here we have a full set of embeddings from the NLST dataset that you can download and train a cancer detection model."
Expand Down

0 comments on commit 9c1bb40

Please sign in to comment.