-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathsetup_page.py
More file actions
71 lines (61 loc) · 2.31 KB
/
setup_page.py
File metadata and controls
71 lines (61 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
from promoai.general_utils.ai_providers import (
AI_HELP_DEFAULTS,
AI_MODEL_DEFAULTS,
AIProviders,
MAIN_HELP,
)
from promoai.general_utils.llm_connection import LLMConnection
def run_page():
    """Render the API-configuration page and store LLM credentials on save.

    On first render this seeds ``st.session_state`` with a default provider
    (the first key of ``AI_MODEL_DEFAULTS``) and that provider's default model
    name. It then shows inputs for provider, model name and API key, plus an
    endpoint field when the Azure provider is selected.

    When "Save Credentials" is clicked and every required field is non-blank,
    an ``LLMConnection`` is stored under
    ``st.session_state["llm_credentials"]``. Otherwise an error message is
    shown and nothing is stored.
    """
    # Initialize state once per session; the widgets below bind to these keys.
    if "provider" not in st.session_state:
        st.session_state["provider"] = next(iter(AI_MODEL_DEFAULTS))
    if "model_name" not in st.session_state:
        st.session_state["model_name"] = AI_MODEL_DEFAULTS[st.session_state["provider"]]

    def on_provider_change():
        # Switching provider resets the model field to that provider's default,
        # so a model name belonging to another provider is not carried over.
        st.session_state["model_name"] = AI_MODEL_DEFAULTS[st.session_state["provider"]]

    # UI Elements
    st.write("### 🔑 API Configuration")
    # Empty marker element; presumably targeted by custom CSS elsewhere — confirm.
    st.markdown('<div class="api-config-marker"></div>', unsafe_allow_html=True)
    with st.container(border=True):
        provider = st.selectbox(
            "AI Provider",
            options=list(AI_MODEL_DEFAULTS),
            key="provider",
            on_change=on_provider_change,
            help=MAIN_HELP,
        )

        col1, col2 = st.columns(2)
        with col1:
            ai_model_name = st.text_input(
                "Model Name",
                key="model_name",
                help=AI_HELP_DEFAULTS.get(st.session_state["provider"], ""),
            )
        with col2:
            api_key = st.text_input(
                "API Key", type="password", placeholder="my-precious-api-key"
            )

        is_azure = provider == AIProviders.AZURE.value
        azure_endpoint = None
        if is_azure:
            azure_endpoint = st.text_input(
                "Azure Endpoint",
                key="azure_endpoint",
                placeholder="https://your-resource.openai.azure.com/",
            )

        if st.button("Save Credentials", type="primary", use_container_width=True):
            # Pasted keys/URLs often carry stray spaces or newlines, and a
            # whitespace-only value is effectively empty, so normalize first.
            api_key = (api_key or "").strip()
            ai_model_name = (ai_model_name or "").strip()
            if azure_endpoint is not None:
                azure_endpoint = azure_endpoint.strip()

            if not api_key:
                st.error("Please enter an API key.")
            elif not ai_model_name:
                st.error("Please enter a model name.")
            elif is_azure and not azure_endpoint:
                # Azure needs a resource endpoint; without it the connection
                # object would be built with an empty END_POINT.
                st.error("Please enter the Azure endpoint.")
            else:
                # END_POINT is None for non-Azure providers (field not shown).
                args = {"END_POINT": azure_endpoint}
                st.session_state["llm_credentials"] = LLMConnection(
                    api_key=api_key,
                    llm_name=ai_model_name,
                    ai_provider=provider,
                    args=args,
                )
                st.success(
                    "Credentials saved! You can now navigate to ProMoAI or PMAx."
                )
# Render the page both when this file is executed directly and when it is
# loaded as a page. "__page__" is presumably the module name Streamlit's
# multipage runner assigns to page scripts — confirm against the Streamlit version in use.
if __name__ in ("__main__", "__page__"):
    run_page()