import sys
import os
import pika
from pymongo import MongoClient
from bson import ObjectId, json_util
from openai import OpenAI, RateLimitError
from retry import retry
import json
from config import LLM, MONGO
from utils import get_channel, now, get_logger
from .classroom_session import ClassroomSession
from .functions import get_function
from enum import StrEnum
class INCLASS_SERVICE_STATUS(StrEnum):
    '''
    Define the status of an in-class session.
    
    Status:
    - `TRIGGERED`: the inclass worker is triggered in mongo
    - `CLASS_ENDED`: the inclass worker is finished
    - `STREAMING`: the inclass worker is waiting llms to generate
    - `PROCESSED`: the inclass worker is processed 
    '''
    TRIGGERED = "triggered" # the inclass worker is triggered in mongo
    CLASS_ENDED = "class_ended" # the inclass worker is finished
    STREAMING = "streaming" # the inclass worker is waiting llms to generate
    PROCESSED = "processed" # the inclass worker is processed  
class INCLASS_SERVICE:
    """
    Service for handling in-class sessions, exposed through static methods.

    Methods:
    - trigger: Processes user input, marks the session as triggered and queues a job for it.
    - launch_worker: Runs a worker that consumes session jobs from a RabbitMQ queue.
    - get_status: Returns the stored session document of a given in-class session.
    - get_updates: Retrieves updates for a given in-class session. (NOT IMPLEMENTED YET)
    """

    # Mongo collection `inclass.session`. The client is created when this
    # class body is executed, i.e. at import time.
    collection = MongoClient(MONGO.HOST, MONGO.PORT).inclass.session

    # RabbitMQ queue name used by both the producer and the consumer side.
    queue_name = "inclass-main"

    # Logger shared by every static method of the service.
    logger = get_logger(__name__=__name__, __file__=__file__)
    @staticmethod
    def trigger(
            session_id,
            user_input: str | dict = "", # Set this to str or dict for user input string or user interactions
            parent_service: str = "",
            parent_job_id=None, # set this to be ObjectId if need callback
    ) -> str:
        """
        Triggers an in-class session by processing user input and updating the session state.

        Steps, in order: open a RabbitMQ connection and declare the queue,
        look up the session's current function, optionally append the user
        input to the session history, mark the session as `TRIGGERED` in Mongo,
        and publish the session id to the queue.

        Args:
            session_id (str): The unique identifier for the session.
            user_input (str | dict, optional): The user input string or interactions.
                Strings are stripped of surrounding whitespace; other values are
                passed through unchanged. Defaults to an empty string (no input).
            parent_service (str, optional): The parent service identifier.
            parent_job_id (ObjectId, optional): The parent job identifier for callback purposes.
        Returns:
            str: The session ID of the triggered session, exactly as it was passed in.
                It is also returned, without queueing anything, when the session
                has no current function (the class is over).
        """
        connection = pika.BlockingConnection(
            pika.ConnectionParameters(host='localhost'))
        # try/finally so the connection is released on the early return and on
        # any exception raised by the session, Mongo or the broker.
        try:
            channel = connection.channel()
            channel.queue_declare(
                queue=INCLASS_SERVICE.queue_name,
                durable=True
            )
            session = ClassroomSession(session_id)
            function_session = session.get_current_function()
            if not function_session:
                INCLASS_SERVICE.logger.info(f"Class is over for Session ID {session_id}")
                return session_id
            executor_name = function_session['call']
            if user_input:
                INCLASS_SERVICE.logger.info("Adding User Input to History")
                # Only strings can be stripped; dict interactions are stored as-is.
                if isinstance(user_input, str):
                    user_input = user_input.strip()
                if executor_name == 'AskQuestion':
                    session.add_user_message(user_input, 'answer')
                else:
                    session.add_user_message(user_input)
            INCLASS_SERVICE.logger.info("Setting Session Status to Triggered in Mongo")
            INCLASS_SERVICE.collection.update_one(
                dict(
                    _id=ObjectId(session_id)
                ),
                {
                    "$set": {
                        "state": INCLASS_SERVICE_STATUS.TRIGGERED.value,
                        "parent_service": parent_service,
                        "parent_job_id": parent_job_id,
                    }
                }
            )
            INCLASS_SERVICE.logger.info("Pushing job to RabbitMQ")
            channel.basic_publish(
                exchange="",
                routing_key=INCLASS_SERVICE.queue_name,
                # The message body must be str/bytes; str() also accepts an ObjectId id.
                body=str(session_id),
                # delivery_mode=2 marks the message persistent so that it matches
                # the durable queue and is not dropped on a broker restart.
                properties=pika.BasicProperties(delivery_mode=2),
            )
        finally:
            # Guard against closing a connection the broker already dropped,
            # which would mask the original exception.
            if connection.is_open:
                connection.close()
        INCLASS_SERVICE.logger.info("Job pushed to RabbitMQ")
        return session_id
    @staticmethod
    def launch_worker():
        """
        Launches a worker that listens to a RabbitMQ queue and processes session jobs.

        Each message body is expected to be the hex string of a session ObjectId.
        For every job the worker either:
        - keeps the session in `STREAMING` and requeues the job, when the session
          is still streaming or the executed function asks to continue;
        - marks the session `PROCESSED` and acknowledges the job;
        - marks the session `CLASS_ENDED` and acknowledges the job, when the
          session has no current function left.
        A message whose body is not a valid ObjectId is logged and rejected
        without requeue, so a single bad payload cannot crash the worker on
        every redelivery.

        Blocks until interrupted with CTRL+C, which terminates the process.

        Returns:
            None
        """
        # Local import: the error type is only needed by this worker.
        from bson.errors import InvalidId

        def set_state(session_id, state):
            # Persist the new lifecycle state on the session document.
            INCLASS_SERVICE.collection.update_one(
                {"_id": session_id},
                {"$set": {"state": state.value}},
            )

        def callback(ch, method, properties, body):
            try:
                session_id = ObjectId(body.decode())
            except (UnicodeDecodeError, InvalidId):
                # Poison message: it can never be processed, so drop it instead
                # of letting the exception kill the consumer and requeue it.
                INCLASS_SERVICE.logger.error(
                    f"Dropping malformed InClass Session Job payload - {body!r}"
                )
                ch.basic_reject(delivery_tag=method.delivery_tag, requeue=False)
                return

            INCLASS_SERVICE.logger.info(f"Entering InClass Session Job - {session_id}")
            session = ClassroomSession(session_id)
            if session.is_streaming(): # If LLM is still generating
                INCLASS_SERVICE.logger.info(f"Entering is_streaming - {session_id}")
                continue_generate = True
            else:
                function_session = session.get_current_function()
                if not function_session: # If Classroom Already Finished
                    set_state(session_id, INCLASS_SERVICE_STATUS.CLASS_ENDED)
                    ch.basic_ack(delivery_tag=method.delivery_tag)
                    return
                executor = get_function(function_session['call'])
                continue_generate = executor.step(
                    value=function_session['value'],
                    function_id=str(function_session['_id']),
                    classroom_session=session,
                )

            if continue_generate:
                set_state(session_id, INCLASS_SERVICE_STATUS.STREAMING)
                # Requeue so the same job is picked up again later.
                # NOTE(review): the broker may redeliver immediately; confirm
                # that polling without a delay is intended.
                ch.basic_reject(delivery_tag=method.delivery_tag, requeue=True)
            else:
                set_state(session_id, INCLASS_SERVICE_STATUS.PROCESSED)
                ch.basic_ack(delivery_tag=method.delivery_tag)
            INCLASS_SERVICE.logger.info(f"Session Processed {session_id}")

        try:
            connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost'))
            channel = connection.channel()
            channel.queue_declare(
                queue=INCLASS_SERVICE.queue_name,
                durable=True
            )
            channel.basic_consume(
                queue=INCLASS_SERVICE.queue_name,
                on_message_callback=callback,
                auto_ack=False,  # jobs are acked/rejected explicitly in callback
            )
            INCLASS_SERVICE.logger.info('Worker Launched. To exit press CTRL+C')
            channel.start_consuming()
        except KeyboardInterrupt:
            INCLASS_SERVICE.logger.warning('Shutting Off Worker')
            try:
                sys.exit(0)
            except SystemExit:
                # sys.exit raises SystemExit; os._exit ends the process right away.
                os._exit(0)
        
    @staticmethod
    def get_status(
        session_id: str = "",
    ) -> str:
        """
        Retrieves the current status of a given in-class session.

        Args:
            session_id (str): The unique identifier for the session. The default
                empty string is never a valid id, so calling without an argument
                still yields the error message below.
        Returns:
            str: The session document serialized as a JSON string,
                "Session not found" when no document has that id, or
                "ERROR in INCLASS_SERVICE get_status" when the lookup or the
                serialization raises (the exception is logged, not re-raised).
        """
        try:
            # Look the session document up by its ObjectId.
            session_status = INCLASS_SERVICE.collection.find_one(
                dict(
                    _id=ObjectId(session_id)
                )
            )
            if not session_status:
                return "Session not found"
            # json_util.default handles BSON-specific values such as ObjectId
            # that the plain json encoder cannot serialize.
            session_status_json = json.dumps(session_status, default=json_util.default)
            return session_status_json  # Return serialized JSON
        except Exception as e:
            INCLASS_SERVICE.logger.error(f"Error in INCLASS_SERVICE get_status: {e}")
            return "ERROR in INCLASS_SERVICE get_status"
        
    
    @staticmethod
    def get_updates():
        """
        Placeholder for retrieving updates of an in-class session.

        Not implemented yet: takes no arguments and returns None.
        """
        return None
 
    
# # TODO: Write Get Update Logic
# @staticmethod
# def get_response(job_id):
# 	record = INCLASS_SERVICE.collection.find_one(dict(_id=job_id))
# 	if not record:
# 		INCLASS_SERVICE.logger.error(f"Job With ID of {job_id} not found")
# 	elif "completion_time" not in record:
# 		INCLASS_SERVICE.logger.error(f"Retrieving Response From Un-Finished Job With ID of {job_id}.")
# 	else:
# 		return record["response"]
# 	return None
# Script entry point: running this module directly starts the queue worker.
if __name__ == "__main__":
    INCLASS_SERVICE.logger.warning("STARTING INCLASS SERVICE")
    INCLASS_SERVICE.launch_worker()