|
| 1 | +import logging |
| 2 | +import os |
| 3 | + |
| 4 | +import streamlit as st |
| 5 | + |
| 6 | +from recsys.config import settings |
| 7 | +from recsys.ui.feature_group_updater import get_fg_updater |
| 8 | +from recsys.ui.interaction_tracker import get_tracker |
| 9 | +from recsys.ui.recommenders import customer_recommendations, llm_recommendations |
| 10 | +from recsys.ui.utils import get_deployments |
| 11 | + |
| 12 | +# Configure logging |
| 13 | +logging.basicConfig(level=logging.INFO) |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
| 16 | +# Constants |
| 17 | +CUSTOMER_IDS = [ |
| 18 | + "9e619265e3ae0d2ef96a71577c4aff3474bfa7dd0d60486b42bc8f921c3387c0", |
| 19 | + "a1f7201399574e78b0a1575c50e3b68d116f84e24c0f70c957083da99db6ab5f", |
| 20 | + "19fa659096de20f0c022b9727779e849813ccc82952b3d56e212ab18fa2c0bf3", |
| 21 | + "d9448c8585f1678937deb5118d95b09bf6f41fe00a65b1fb82c7d176c6bfc532", |
| 22 | + "b41d990c8a127dac386dd6c9f2a6ec4ac41185cd21ef2df0a952a8cbdf61ed5d", |
| 23 | +] |
| 24 | + |
| 25 | + |
| 26 | +def initialize_page(): |
| 27 | + """Initialize Streamlit page configuration""" |
| 28 | + st.set_page_config(layout="wide", initial_sidebar_state="expanded") |
| 29 | + st.title("👒 Fashion Items Recommender") |
| 30 | + st.sidebar.title("⚙️ Configuration") |
| 31 | + |
| 32 | + |
| 33 | +def initialize_services(): |
| 34 | + """Initialize tracker, updater, and deployments""" |
| 35 | + tracker = get_tracker() |
| 36 | + fg_updater = get_fg_updater() |
| 37 | + |
| 38 | + logger.info("Initializing deployments...") |
| 39 | + with st.sidebar: |
| 40 | + with st.spinner("🚀 Starting Deployments..."): |
| 41 | + articles_fv, ranking_deployment, query_model_deployment = get_deployments() |
| 42 | + st.success("✅ Deployments Ready") |
| 43 | + |
| 44 | + # Stop deployments button |
| 45 | + if st.button( |
| 46 | + "⏹️ Stop Deployments", key="stop_deployments_button", type="secondary" |
| 47 | + ): |
| 48 | + ranking_deployment.stop() |
| 49 | + query_model_deployment.stop() |
| 50 | + st.success("Deployments stopped successfully!") |
| 51 | + |
| 52 | + return tracker, fg_updater, articles_fv, ranking_deployment, query_model_deployment |
| 53 | + |
| 54 | + |
| 55 | +def show_interaction_dashboard(tracker, fg_updater, page_selection): |
| 56 | + """Display interaction data and controls""" |
| 57 | + with st.sidebar.expander("📊 Interaction Dashboard", expanded=True): |
| 58 | + if page_selection == "LLM Recommendations": |
| 59 | + api_key = ( |
| 60 | + settings.OPENAI_API_KEY.get_secret_value() |
| 61 | + if settings.OPENAI_API_KEY |
| 62 | + and settings.OPENAI_API_KEY.get_secret_value() |
| 63 | + else None |
| 64 | + ) |
| 65 | + if not api_key: |
| 66 | + api_key = st.text_input( |
| 67 | + "🔑 OpenAI API Key:", type="password", key="openai_api_key" |
| 68 | + ) |
| 69 | + if api_key: |
| 70 | + os.environ["OPENAI_API_KEY"] = api_key |
| 71 | + else: |
| 72 | + st.warning("⚠️ Please enter OpenAI API Key for LLM Recommendations") |
| 73 | + st.divider() |
| 74 | + |
| 75 | + interaction_data = tracker.get_interactions_data() |
| 76 | + |
| 77 | + col1, col2, col3 = st.columns(3) |
| 78 | + total = len(interaction_data) |
| 79 | + clicks = len(interaction_data[interaction_data["interaction_score"] == 1]) |
| 80 | + purchases = len(interaction_data[interaction_data["interaction_score"] == 2]) |
| 81 | + |
| 82 | + col1.metric("Total", total) |
| 83 | + col2.metric("Clicks", clicks) |
| 84 | + col3.metric("Purchases", purchases) |
| 85 | + |
| 86 | + st.dataframe(interaction_data, hide_index=True) |
| 87 | + fg_updater.process_interactions(tracker, force=True) |
| 88 | + |
| 89 | + |
| 90 | +def handle_llm_page(articles_fv, customer_id): |
| 91 | + """Handle LLM recommendations page""" |
| 92 | + if "OPENAI_API_KEY" in os.environ: |
| 93 | + llm_recommendations(articles_fv, os.environ["OPENAI_API_KEY"], customer_id) |
| 94 | + else: |
| 95 | + st.warning("Please provide your OpenAI API Key in the Interaction Dashboard") |
| 96 | + |
| 97 | + |
| 98 | +def process_pending_interactions(tracker, fg_updater): |
| 99 | + """Process interactions immediately""" |
| 100 | + fg_updater.process_interactions(tracker, force=True) |
| 101 | + |
| 102 | + |
| 103 | +def main(): |
| 104 | + # Initialize page |
| 105 | + initialize_page() |
| 106 | + |
| 107 | + # Initialize services |
| 108 | + tracker, fg_updater, articles_fv, ranking_deployment, query_model_deployment = ( |
| 109 | + initialize_services() |
| 110 | + ) |
| 111 | + |
| 112 | + # Select customer |
| 113 | + customer_id = st.sidebar.selectbox( |
| 114 | + "👤 Select Customer:", CUSTOMER_IDS, key="selected_customer" |
| 115 | + ) |
| 116 | + |
| 117 | + # Page selection |
| 118 | + page_options = ["Customer Recommendations", "LLM Recommendations"] |
| 119 | + page_selection = st.sidebar.radio("📑 Choose Page:", page_options) |
| 120 | + |
| 121 | + # Process any pending interactions with notification |
| 122 | + process_pending_interactions(tracker, fg_updater) |
| 123 | + |
| 124 | + # Interaction dashboard with OpenAI API key field |
| 125 | + show_interaction_dashboard(tracker, fg_updater, page_selection) |
| 126 | + |
| 127 | + # Handle page content |
| 128 | + if page_selection == "Customer Recommendations": |
| 129 | + customer_recommendations( |
| 130 | + articles_fv, ranking_deployment, query_model_deployment, customer_id |
| 131 | + ) |
| 132 | + else: # LLM Recommendations |
| 133 | + handle_llm_page(articles_fv, customer_id) |
| 134 | + |
| 135 | + |
| 136 | +if __name__ == "__main__": |
| 137 | + main() |
0 commit comments