Upload 12 files
Browse files
- openai_fake_data_generator.py +20 -29
- presidio_helpers.py +2 -8
- presidio_streamlit.py +10 -7
openai_fake_data_generator.py
CHANGED
|
@@ -2,54 +2,45 @@ from collections import namedtuple
|
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
import openai
|
|
|
|
| 5 |
import logging
|
| 6 |
|
| 7 |
logger = logging.getLogger("presidio-streamlit")
|
| 8 |
|
| 9 |
OpenAIParams = namedtuple(
|
| 10 |
"open_ai_params",
|
| 11 |
-
["openai_key", "model", "api_base", "deployment_name", "api_version", "api_type"],
|
| 12 |
)
|
| 13 |
|
| 14 |
|
| 15 |
-
def set_openai_params(openai_params: OpenAIParams):
|
| 16 |
-
"""Set the OpenAI API key.
|
| 17 |
-
:param openai_params: OpenAIParams object with the following fields: key, model, api version, deployment_name,
|
| 18 |
-
The latter only relate to Azure OpenAI deployments.
|
| 19 |
-
"""
|
| 20 |
-
openai.api_key = openai_params.openai_key
|
| 21 |
-
openai.api_version = openai_params.api_version
|
| 22 |
-
if openai_params.api_base:
|
| 23 |
-
openai.api_base = openai_params.api_base
|
| 24 |
-
openai.api_type = openai_params.api_type
|
| 25 |
-
|
| 26 |
-
|
| 27 |
def call_completion_model(
|
| 28 |
prompt: str,
|
| 29 |
-
|
| 30 |
-
max_tokens: int =
|
| 31 |
-
deployment_id: Optional[str] = None,
|
| 32 |
) -> str:
|
| 33 |
"""Creates a request for the OpenAI Completion service and returns the response.
|
| 34 |
|
| 35 |
:param prompt: The prompt for the completion model
|
| 36 |
-
:param
|
| 37 |
-
:param max_tokens:
|
| 38 |
-
:param deployment_id: Azure OpenAI deployment ID
|
| 39 |
"""
|
| 40 |
-
if
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
)
|
| 47 |
else:
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
return response
|
| 53 |
|
| 54 |
|
| 55 |
def create_prompt(anonymized_text: str) -> str:
|
|
|
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
import openai
|
| 5 |
+
from openai import OpenAI, AzureOpenAI
|
| 6 |
import logging
|
| 7 |
|
| 8 |
logger = logging.getLogger("presidio-streamlit")
|
| 9 |
|
| 10 |
OpenAIParams = namedtuple(
|
| 11 |
"open_ai_params",
|
| 12 |
+
["openai_key", "model", "api_base", "deployment_id", "api_version", "api_type"],
|
| 13 |
)
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def call_completion_model(
|
| 17 |
prompt: str,
|
| 18 |
+
openai_params: OpenAIParams,
|
| 19 |
+
max_tokens: Optional[int] = 256,
|
|
|
|
| 20 |
) -> str:
|
| 21 |
"""Creates a request for the OpenAI Completion service and returns the response.
|
| 22 |
|
| 23 |
:param prompt: The prompt for the completion model
|
| 24 |
+
:param openai_params: OpenAI parameters for the completion model
|
| 25 |
+
:param max_tokens: The maximum number of tokens to generate.
|
|
|
|
| 26 |
"""
|
| 27 |
+
if openai_params.api_type.lower() == "azure":
|
| 28 |
+
client = AzureOpenAI(
|
| 29 |
+
api_version=openai_params.api_version,
|
| 30 |
+
api_key=openai_params.openai_key,
|
| 31 |
+
azure_endpoint=openai_params.api_base,
|
| 32 |
+
azure_deployment=openai_params.deployment_id,
|
| 33 |
)
|
| 34 |
else:
|
| 35 |
+
client = OpenAI(api_key=openai_params.openai_key)
|
| 36 |
+
|
| 37 |
+
response = client.completions.create(
|
| 38 |
+
model=openai_params.model,
|
| 39 |
+
prompt=prompt,
|
| 40 |
+
max_tokens=max_tokens,
|
| 41 |
+
)
|
| 42 |
|
| 43 |
+
return response.choices[0].text.strip()
|
| 44 |
|
| 45 |
|
| 46 |
def create_prompt(anonymized_text: str) -> str:
|
presidio_helpers.py
CHANGED
|
@@ -16,10 +16,9 @@ from presidio_anonymizer import AnonymizerEngine
|
|
| 16 |
from presidio_anonymizer.entities import OperatorConfig
|
| 17 |
|
| 18 |
from openai_fake_data_generator import (
|
| 19 |
-
set_openai_params,
|
| 20 |
call_completion_model,
|
| 21 |
-
create_prompt,
|
| 22 |
OpenAIParams,
|
|
|
|
| 23 |
)
|
| 24 |
from presidio_nlp_engine_config import (
|
| 25 |
create_nlp_engine_with_spacy,
|
|
@@ -218,14 +217,9 @@ def create_fake_data(
|
|
| 218 |
if not openai_params.openai_key:
|
| 219 |
return "Please provide your OpenAI key"
|
| 220 |
results = anonymize(text=text, operator="replace", analyze_results=analyze_results)
|
| 221 |
-
set_openai_params(openai_params)
|
| 222 |
prompt = create_prompt(results.text)
|
| 223 |
print(f"Prompt: {prompt}")
|
| 224 |
-
fake =
|
| 225 |
-
prompt=prompt,
|
| 226 |
-
openai_model_name=openai_params.model,
|
| 227 |
-
openai_deployment_name=openai_params.deployment_name,
|
| 228 |
-
)
|
| 229 |
return fake
|
| 230 |
|
| 231 |
|
|
|
|
| 16 |
from presidio_anonymizer.entities import OperatorConfig
|
| 17 |
|
| 18 |
from openai_fake_data_generator import (
|
|
|
|
| 19 |
call_completion_model,
|
|
|
|
| 20 |
OpenAIParams,
|
| 21 |
+
create_prompt,
|
| 22 |
)
|
| 23 |
from presidio_nlp_engine_config import (
|
| 24 |
create_nlp_engine_with_spacy,
|
|
|
|
| 217 |
if not openai_params.openai_key:
|
| 218 |
return "Please provide your OpenAI key"
|
| 219 |
results = anonymize(text=text, operator="replace", analyze_results=analyze_results)
|
|
|
|
| 220 |
prompt = create_prompt(results.text)
|
| 221 |
print(f"Prompt: {prompt}")
|
| 222 |
+
fake = call_completion_model(prompt=prompt, openai_params=openai_params)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
return fake
|
| 224 |
|
| 225 |
|
presidio_streamlit.py
CHANGED
|
@@ -135,7 +135,8 @@ def set_up_openai_synthesis():
|
|
| 135 |
"Azure OpenAI base URL",
|
| 136 |
value=os.getenv("AZURE_OPENAI_ENDPOINT", default=""),
|
| 137 |
)
|
| 138 |
-
|
|
|
|
| 139 |
"Deployment name", value=os.getenv("AZURE_OPENAI_DEPLOYMENT", default="")
|
| 140 |
)
|
| 141 |
st_openai_version = st.sidebar.text_input(
|
|
@@ -143,11 +144,13 @@ def set_up_openai_synthesis():
|
|
| 143 |
value=os.getenv("OPENAI_API_VERSION", default="2023-05-15"),
|
| 144 |
)
|
| 145 |
else:
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
st_openai_key = st.sidebar.text_input(
|
| 149 |
"OPENAI_KEY",
|
| 150 |
-
value=
|
| 151 |
help="See https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key for more info.",
|
| 152 |
type="password",
|
| 153 |
)
|
|
@@ -159,7 +162,7 @@ def set_up_openai_synthesis():
|
|
| 159 |
return (
|
| 160 |
openai_api_type,
|
| 161 |
st_openai_api_base,
|
| 162 |
-
|
| 163 |
st_openai_version,
|
| 164 |
st_openai_key,
|
| 165 |
st_openai_model,
|
|
@@ -179,7 +182,7 @@ elif st_operator == "synthesize":
|
|
| 179 |
(
|
| 180 |
openai_api_type,
|
| 181 |
st_openai_api_base,
|
| 182 |
-
|
| 183 |
st_openai_version,
|
| 184 |
st_openai_key,
|
| 185 |
st_openai_model,
|
|
@@ -189,7 +192,7 @@ elif st_operator == "synthesize":
|
|
| 189 |
openai_key=st_openai_key,
|
| 190 |
model=st_openai_model,
|
| 191 |
api_base=st_openai_api_base,
|
| 192 |
-
|
| 193 |
api_version=st_openai_version,
|
| 194 |
api_type=openai_api_type,
|
| 195 |
)
|
|
|
|
| 135 |
"Azure OpenAI base URL",
|
| 136 |
value=os.getenv("AZURE_OPENAI_ENDPOINT", default=""),
|
| 137 |
)
|
| 138 |
+
openai_key = os.getenv("AZURE_OPENAI_KEY", default="")
|
| 139 |
+
st_deployment_id = st.sidebar.text_input(
|
| 140 |
"Deployment name", value=os.getenv("AZURE_OPENAI_DEPLOYMENT", default="")
|
| 141 |
)
|
| 142 |
st_openai_version = st.sidebar.text_input(
|
|
|
|
| 144 |
value=os.getenv("OPENAI_API_VERSION", default="2023-05-15"),
|
| 145 |
)
|
| 146 |
else:
|
| 147 |
+
openai_api_type = "openai"
|
| 148 |
+
st_openai_version = st_openai_api_base = None
|
| 149 |
+
st_deployment_id = ""
|
| 150 |
+
openai_key = os.getenv("OPENAI_KEY", default="")
|
| 151 |
st_openai_key = st.sidebar.text_input(
|
| 152 |
"OPENAI_KEY",
|
| 153 |
+
value=openai_key,
|
| 154 |
help="See https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key for more info.",
|
| 155 |
type="password",
|
| 156 |
)
|
|
|
|
| 162 |
return (
|
| 163 |
openai_api_type,
|
| 164 |
st_openai_api_base,
|
| 165 |
+
st_deployment_id,
|
| 166 |
st_openai_version,
|
| 167 |
st_openai_key,
|
| 168 |
st_openai_model,
|
|
|
|
| 182 |
(
|
| 183 |
openai_api_type,
|
| 184 |
st_openai_api_base,
|
| 185 |
+
st_deployment_id,
|
| 186 |
st_openai_version,
|
| 187 |
st_openai_key,
|
| 188 |
st_openai_model,
|
|
|
|
| 192 |
openai_key=st_openai_key,
|
| 193 |
model=st_openai_model,
|
| 194 |
api_base=st_openai_api_base,
|
| 195 |
+
deployment_id=st_deployment_id,
|
| 196 |
api_version=st_openai_version,
|
| 197 |
api_type=openai_api_type,
|
| 198 |
)
|