-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathembedding.py
67 lines (54 loc) · 2.21 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from openai import OpenAI
from google.cloud import storage
import json
# Read the OpenAI API key from the environment; os.getenv yields None when
# OPENAI_API_KEY is not set.
openai_api_key = os.getenv('OPENAI_API_KEY')
# Module-level OpenAI client; create_embedding() reaches it through the
# global name `openai`.
openai = OpenAI(api_key=openai_api_key)
# Initialize the Google Cloud Storage client.
# NOTE(review): the storage helpers in this file each construct their own
# storage.Client(), so this module-level instance looks unused here — confirm.
storage_client = storage.Client()
bucket_name = 'your-bucket-name' # Set the bucket name (value looks like a placeholder — confirm)
def download_text_from_storage(bucket_name, source_blob_name):
    """Fetch an object from Cloud Storage and return its contents as text.

    Args:
        bucket_name: Name of the bucket that holds the object.
        source_blob_name: Name of the object to read.

    Returns:
        Whatever ``Blob.download_as_text()`` returns for the object.

    Raises:
        Exception: Any error raised while talking to Cloud Storage is
            printed and then re-raised unchanged.
    """
    try:
        # A fresh client is built on every call.
        target = storage.Client().bucket(bucket_name).blob(source_blob_name)
        text = target.download_as_text()
    except Exception as e:
        print(f"Failed to download file: {e}")
        raise
    else:
        return text
def create_embedding(input_text):
    """Return the embedding vector for ``input_text``.

    Sends ``input_text`` to the embeddings endpoint through the module-level
    ``openai`` client, using the ``text-embedding-3-large`` model, and
    returns the embedding of the first result in the response.

    Args:
        input_text: Text to embed.

    Returns:
        The ``embedding`` field of the first item in the response data.

    Raises:
        Exception: Errors raised by the OpenAI client propagate unchanged.
    """
    response = openai.embeddings.create(
        model='text-embedding-3-large',
        input=input_text,
    )
    # The v1 client (``OpenAI(...)``) returns a typed response object rather
    # than a dict, so the result must be read through attributes;
    # subscripting it as response['data'] raises TypeError.
    return response.data[0].embedding
def upload_blob(bucket_name, source_stream, destination_blob_name, content_type='text'):
    """Upload in-memory data to a bucket and return the object's URL.

    Args:
        bucket_name: Name of the destination bucket.
        source_stream: Data to store; passed as-is to
            ``Blob.upload_from_string``.
        destination_blob_name: Name of the destination object.
        content_type: Content type recorded for the uploaded object.
            NOTE(review): the default ``'text'`` is kept for backward
            compatibility but is not a standard media type —
            ``'text/plain'`` may have been intended; confirm.

    Returns:
        A ``https://storage.googleapis.com/<bucket>/<object>`` URL with the
        object name percent-encoded. This function only builds the URL; it
        does not change the object's access settings.

    Raises:
        Exception: Any error raised during the upload is printed and then
            re-raised unchanged.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    try:
        storage_client = storage.Client()
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(destination_blob_name)
        blob.upload_from_string(source_stream, content_type=content_type)
        # Object names may contain spaces, '#', '?', '%' and similar
        # characters; they must be percent-encoded in the URL path ('/' is
        # kept so nested object names stay readable).
        encoded_name = quote(destination_blob_name, safe='/')
        public_url = f"https://storage.googleapis.com/{bucket_name}/{encoded_name}"
        return public_url
    except Exception as e:
        print(f"Failed to upload file: {e}")
        raise
def embedding_from_storage(bucket_name):
    """Embed a text object from the bucket and store the vector as JSON.

    Downloads ``embedding_input.txt`` from ``bucket_name``, creates an
    embedding for its contents, serializes the vector with ``json.dumps``
    and uploads it to the same bucket as ``embedding_data.json``.

    Args:
        bucket_name: Bucket that holds the input object and receives the
            output object.

    Returns:
        The URL returned by ``upload_blob`` for the uploaded JSON object.

    Raises:
        Exception: Any error from the download, embedding or upload step is
            printed and then re-raised unchanged.
    """
    source_blob_name = 'embedding_input.txt'
    destination_blob_name = 'embedding_data.json'
    try:
        # Download the text data from Cloud Storage.
        input_text = download_text_from_storage(bucket_name, source_blob_name)
        # Generate the vector data.
        embedding_vector = create_embedding(input_text)
        # Convert the vector data to JSON.
        embedding_json = json.dumps(embedding_vector)
        # Upload to Cloud Storage and return the resulting URL. The payload
        # is JSON, so label it as such instead of relying on the helper's
        # default content type of 'text'.
        return upload_blob(
            bucket_name,
            embedding_json,
            destination_blob_name,
            content_type='application/json',
        )
    except Exception as e:
        print(f"Error: {e}")
        raise