<?php
namespace boru\boruai\Openai;

use boru\boruai\BoruAI;
use boru\boruai\Models\Tool;
use boru\boruai\Models\ToolDefinition;
use boru\boruai\Openai\Api\Endpoints\ChatCompletionsAPI;
use boru\boruai\Openai\Api\Responses\ChatCompletionResponse;
use boru\boruai\Openai\OpenAI;
use boru\boruai\Openai\Models\BaseModel;
use boru\boruai\Openai\Models\Messages;
use boru\boruai\Tiktoken\Encoder;
use boru\boruai\Tiktoken\EncoderProvider;
use Exception;

/**
 * Chat-completion request builder and runner.
 *
 * Holds the request parameters and the conversation history, serialises them with
 * toArray(), sends them through ChatCompletionsAPI::create() in run(), and executes
 * any tool calls the model requests before re-sending the conversation.
 */
class OpenAIChat extends BaseModel {

    /**
     * @var string|null
     * Model identifier sent as "model". The constructor overwrites it with a non-array
     * first argument; a null model is replaced by BoruAI::defaultModel("default").
     */
    private $model = "gpt-4";

    /**
     * @var Messages
     * The conversation sent as "messages". Entries carry a role (e.g. "system", "user",
     * "assistant", "tool") and content; assistant entries returned by the API may also
     * carry "tool_calls", and tool results carry a "tool_call_id" (see run()).
     */
    private $messages;

    /**
     * @var float|null
     * Sampling temperature (0-2). Higher values make output more random, lower values more
     * focused. Omitted from the request when null, so the API default applies.
     * Generally alter this or top_p, not both.
     */
    private $temperature = null;

    /**
     * @var float|null
     * Nucleus sampling: only tokens within the top_p probability mass are considered
     * (0.1 = top 10%). Omitted from the request when null.
     * Generally alter this or temperature, not both.
     */
    private $top_p = null;

    /**
     * @var int|null
     * Upper bound on the number of tokens generated for the completion. Omitted from the
     * request when null, letting the API apply the model's own limit.
     */
    private $max_tokens = null;

    /**
     * @var int
     * Number of completion choices to generate. Always sent.
     */
    private $n = 1;

    /**
     * @var bool
     * Sent as-is as "stream".
     * NOTE(review): run() reads a single ChatCompletionResponse; confirm ChatCompletionsAPI
     * supports streamed responses before enabling this.
     */
    private $stream = false;

    /**
     * @var array
     * Sequences at which the API stops generating further tokens. Omitted when empty.
     */
    private $stop = [];

    /**
     * @var float
     * Penalises tokens that have already appeared at all, nudging the model toward new topics.
     */
    private $presence_penalty = 0;

    /**
     * @var float
     * Penalises tokens in proportion to how often they have already appeared, reducing
     * verbatim repetition.
     */
    private $frequency_penalty = 0;

    /**
     * @var ToolDefinition[]
     * Tools offered to the model. Each is serialised by toArray() as
     * {"type": "function", "function": $tool->compile()}.
     */
    private $tools = [];

    /**
     * @var Encoder|null
     * Lazily created tokenizer for the current model; see encoder().
     */
    private $encoder = null;

    /**
     * @param string|array|null $input   model name, or the options array when no model is given
     * @param array             $options options handed to BaseModel; a "messages" key is consumed here
     */
    public function __construct($input=null,$options=[]) {
        if(is_array($input)) {
            $options = $input;
        } else {
            $this->model = $input;
        }
        $this->messages = new Messages();
        if(isset($options["messages"])) {
            $this->addMessagesFromArray($options["messages"]);
            // Removed so BaseModel does not process the raw messages a second time.
            unset($options["messages"]);
        }
        parent::__construct($options);
        if($this->model === null) {
            $this->model = BoruAI::defaultModel("default");
        }
    }

    /**
     * Append plain message arrays (role, content and an optional name) to the conversation.
     * The optional "name" is kept as a message field; it is not a file reference.
     *
     * @param array $messages list of ["role" => ..., "content" => ..., "name" => ...] arrays
     */
    private function addMessagesFromArray($messages) {
        foreach($messages as $msg) {
            $entry = [
                "role" => $msg["role"],
                "content" => $msg["content"],
            ];
            if(isset($msg["name"]) && $msg["name"] !== "") {
                $entry["name"] = $msg["name"];
            }
            $this->messages->add($entry);
        }
    }

    /**
     * Build the chat/completions request body.
     * Optional parameters that are null or empty are left out.
     *
     * @param bool $forSend currently unused; kept for interface compatibility
     * @return array
     */
    public function toArray($forSend=true) {
        $arr = [
            "model" => $this->model,
            "messages" => [], // placeholder keeps "messages" second in the key order
        ];
        if(!is_null($this->temperature)) {
            $arr["temperature"] = $this->temperature;
        }
        if(!is_null($this->top_p)) {
            $arr["top_p"] = $this->top_p;
        }
        if(!is_null($this->max_tokens)) {
            $arr["max_tokens"] = $this->max_tokens;
        }
        $arr["n"] = $this->n;
        $arr["stream"] = $this->stream;
        if(!empty($this->stop)) {
            $arr["stop"] = $this->stop;
        }
        $arr["presence_penalty"] = $this->presence_penalty;
        $arr["frequency_penalty"] = $this->frequency_penalty;
        $arr["messages"] = $this->messages->toArray();
        if(!empty($this->tools)) {
            $arr["tools"] = [];
            foreach($this->tools as $tool) {
                $arr["tools"][] = [
                    "type"=>"function",
                    "function"=>$tool->compile()
                ];
            }
        }
        return $arr;
    }

    /**
     * Send a raw chat/completions request.
     *
     * @param array $parameters    request body
     * @param bool  $returnAsArray true: go through OpenAI::request() and return its raw result;
     *                             false: return a ChatCompletionResponse from ChatCompletionsAPI
     * @return array|ChatCompletionResponse
     * @throws \Exception
     */
    public static function create($parameters,$returnAsArray=false) {
        if($returnAsArray) {
            return OpenAI::request("post","chat/completions",$parameters,$returnAsArray);
        }
        return ChatCompletionsAPI::create($parameters);
    }

    /**
     * Send the conversation and return the result.
     *
     * If the first choice finishes with "tool_calls", the requested tools are executed, the
     * assistant message and tool results are appended to the conversation, and the request
     * is sent again (recursively) until the model stops asking for tools.
     * NOTE(review): there is no cap on tool-call rounds; a model that keeps requesting tools
     * will keep recursing.
     *
     * @param int|false $choice      index of the choice to return, or false to return the full
     *                               response (tool calls are then NOT executed)
     * @param bool      $contentOnly return only the message content when it is present
     * @return ChatCompletionResponse|string|array|mixed the full response, the content string,
     *                               or the choice element itself when it has no content
     * @throws Exception when a requested tool cannot be resolved or the request fails
     */
    public function run($choice=0,$contentOnly=true) {
        $completion = ChatCompletionsAPI::create($this->toArray(true));
        if($choice === false) {
            return $completion;
        }
        // Tool calls are resolved from the first choice only, whatever $choice is returned.
        $first = $completion->choice(0);
        if(isset($first["finish_reason"]) && $first["finish_reason"] === "tool_calls") {
            $this->resolveToolCalls(isset($first["message"]) ? $first["message"] : []);
            return $this->run($choice,$contentOnly);
        }
        $selected = $completion->choice($choice);
        if($contentOnly && isset($selected["message"]["content"])) {
            return $selected["message"]["content"];
        }
        return $selected;
    }

    /**
     * Execute every tool call in an assistant message, then append that assistant message
     * followed by one "tool" message per result. No messages are appended unless every
     * tool call resolves.
     *
     * @param array $assistantMessage "message" element of a choice whose finish_reason is "tool_calls"
     * @throws Exception when Tool::run() returns false for a call
     */
    private function resolveToolCalls($assistantMessage) {
        $toolCalls = isset($assistantMessage["tool_calls"]) ? $assistantMessage["tool_calls"] : [];
        $toolOutputs = [];
        foreach($toolCalls as $toolCall) {
            $toolOutput = (new Tool($toolCall))->run();
            if($toolOutput === false) {
                throw new Exception("Tool call not found ".json_encode($toolCall));
            }
            $toolOutputs[] = [
                "role" => "tool",
                "content" => $toolOutput["output"],
                "tool_call_id" => $toolOutput["tool_call_id"],
            ];
        }
        $this->messages->add($assistantMessage);
        $this->addToolCallMessage($toolOutputs);
    }

    /**
     * Run the chat and interpret the reply as JSON.
     *
     * A surrounding Markdown code fence (``` or ```json, any case) is stripped first.
     *
     * @param bool $asArray true to decode and return an associative array, false to return
     *                      the cleaned JSON text
     * @return array|string|mixed decoded value when $asArray, otherwise the JSON text
     * @throws Exception when the reply has no text content, or when decoding fails
     */
    public function json($asArray=false) {
        $output = $this->run();
        if(is_array($output) || is_object($output)) {
            // run() returns the raw choice when the message has no content (e.g. a refusal).
            throw new Exception("Expected a text response but received: ".json_encode($output));
        }
        $output = trim((string)$output);
        if(substr($output,0,3) === "```") {
            $output = (string)substr($output,3);
            if(strncasecmp($output,"json",4) === 0) {
                $output = (string)substr($output,4);
            }
            $output = ltrim($output);
        }
        if(substr($output,-3) === "```") {
            $output = rtrim((string)substr($output,0,-3));
        }
        if($asArray) {
            $arr = json_decode($output, true);
            // Check the error state rather than null so a literal JSON `null` is accepted.
            if(json_last_error() !== JSON_ERROR_NONE) {
                throw new Exception("Invalid JSON response: ".json_last_error_msg()."\n\n".$output."\n\n");
            }
            return $arr;
        }
        return $output;
    }

    /*
     * Combined getters/setters: pass a non-null value to set the parameter; the current
     * value is always returned.
     */

    public function model($model=null) {
        if($model !== null) {
            if($model !== $this->model) {
                // Tokenizers are model-specific; rebuild lazily on the next encoder() call.
                $this->encoder = null;
            }
            $this->model = $model;
        }
        return $this->model;
    }
    public function temperature($temperature=null) {
        if($temperature !== null) {
            $this->temperature = $temperature;
        }
        return $this->temperature;
    }
    public function topP($top_p=null) {
        if($top_p !== null) {
            $this->top_p = $top_p;
        }
        return $this->top_p;
    }
    public function maxTokens($max_tokens=null) {
        if($max_tokens !== null) {
            $this->max_tokens = $max_tokens;
        }
        return $this->max_tokens;
    }
    public function n($n=null) {
        if($n !== null) {
            $this->n = $n;
        }
        return $this->n;
    }
    public function stream($stream=null) {
        if($stream !== null) {
            $this->stream = $stream;
        }
        return $this->stream;
    }
    public function stop($stop=null) {
        if($stop !== null) {
            $this->stop = $stop;
        }
        return $this->stop;
    }
    public function presencePenalty($presence_penalty=null) {
        if($presence_penalty !== null) {
            $this->presence_penalty = $presence_penalty;
        }
        return $this->presence_penalty;
    }
    public function frequencyPenalty($frequency_penalty=null) {
        if($frequency_penalty !== null) {
            $this->frequency_penalty = $frequency_penalty;
        }
        return $this->frequency_penalty;
    }

    /*
     * Setter aliases for the combined accessors above.
     */
    public function setModel($model) { $this->model($model); }
    public function setTemperature($temperature) { $this->temperature($temperature); }
    public function setTopP($top_p) { $this->topP($top_p); }
    public function setMaxTokens($max_tokens) { $this->maxTokens($max_tokens); }
    public function setN($n) { $this->n($n); }
    public function setStream($stream) { $this->stream($stream); }
    public function setStop($stop) { $this->stop($stop); }
    public function setPresencePenalty($presence_penalty) { $this->presencePenalty($presence_penalty); }
    public function setFrequencyPenalty($frequency_penalty) { $this->frequencyPenalty($frequency_penalty); }

    /**
     * Append tool-result messages (role "tool", content, tool_call_id) to the conversation.
     *
     * @param array $toolCalls list of tool message arrays
     */
    public function addToolCallMessage($toolCalls=[]) {
        foreach($toolCalls as $toolCall) {
            $this->messages->add($toolCall);
        }
    }

    /**
     * Attach a file to the conversation; delegates to Messages::addFile().
     *
     * @param string      $fileIdOrPath uploaded file id or local path
     * @param string      $type         file kind passed through to Messages::addFile()
     * @param string|null $mimeType     optional MIME type
     * @return mixed whatever Messages::addFile() returns (the created message)
     */
    public function addFile($fileIdOrPath,$type="file",$mimeType=null) {
        return $this->messages->addFile($fileIdOrPath,$type,$mimeType);
    }

    /**
     * Add a message to the conversation, optionally attaching a file.
     *
     * With a file, the message created by Messages::addFile() gets the text and role set on
     * it. Without one, a plain ["role", "content"] array is added.
     * NOTE(review): $fileType defaults to null here, while addFile() defaults to "file";
     * confirm Messages::addFile() handles a null type.
     *
     * @param string      $role
     * @param mixed       $content
     * @param string|null $fileIdOrPath optional file id or path to attach
     * @param string|null $fileType
     * @param string|null $mimeType
     * @return mixed the message object returned by Messages::add()/addFile()
     * @throws Exception
     */
    public function addMessage($role,$content,$fileIdOrPath=null,$fileType=null,$mimeType=null) {
        $message = null;
        if($fileIdOrPath) {
            $message = $this->messages->addFile($fileIdOrPath,$fileType,$mimeType);
        }
        if(!$message) {
            return $this->messages->add([
                "role" => $role,
                "content" => $content,
            ]);
        }
        $message->text($content);
        $message->role($role);
        return $message;
    }

    /**
     * Append a single message (array or message object) and return $this for chaining.
     */
    public function message($message) {
        if($this->messages === null) {
            $this->messages = new Messages();
        }
        $this->messages->add($message);
        return $this;
    }

    /**
     * Replace the conversation when $messages is given; always return the Messages collection.
     */
    public function messages($messages=null) {
        if($messages !== null) {
            $this->messages = new Messages($messages);
        }
        return $this->messages;
    }

    /**
     * Replace the tool list when $tools is given; always return the current list.
     */
    public function tools($tools=null) {
        if($tools !== null) {
            $this->tools = $tools;
        }
        return $this->tools;
    }

    /**
     * Add several tools via addTool(). A single tool (ToolDefinition or string) is also
     * accepted without an array wrapper.
     *
     * @param array|ToolDefinition|string|null $arrayOfTools
     * @return ToolDefinition[] the current tool list
     */
    public function addTools($arrayOfTools=[]) {
        if(empty($arrayOfTools)) {
            return $this->tools;
        }
        if($arrayOfTools instanceof ToolDefinition || is_string($arrayOfTools)) {
            $arrayOfTools = [$arrayOfTools];
        }
        foreach($arrayOfTools as $tool) {
            $this->addTool($tool);
        }
        return $this->tools;
    }

    /**
     * Add one tool.
     *
     * Accepts:
     * - a ToolDefinition;
     * - a JSON string describing one;
     * - any other string, which is resolved through BoruAI::loadTool() and may yield several tools;
     * - any other value, which is passed to the ToolDefinition constructor.
     *
     * @param ToolDefinition|string|array $tool
     */
    public function addTool($tool) {
        if($tool instanceof ToolDefinition) {
            $this->tools[] = $tool;
            return;
        }
        if(!is_string($tool)) {
            $this->tools[] = new ToolDefinition($tool);
            return;
        }
        $json = json_decode($tool,true);
        if($json) {
            $this->tools[] = new ToolDefinition($json);
            return;
        }
        $loaded = BoruAI::loadTool($tool);
        if($loaded) {
            foreach($loaded as $definition) {
                $this->tools[] = $definition;
            }
        }
    }

    /**
     * Inspect an API error payload.
     *
     * On a "context_length_exceeded" error for the "messages" parameter, the conversation is
     * compressed so the caller can retry: the last 3 messages are kept when there are more
     * than 3, otherwise the last 1.
     *
     * @param array|string $responseJson decoded payload or raw JSON string
     * @return bool true when the history was compressed, false otherwise
     */
    public function handleError($responseJson) {
        if(!is_array($responseJson)) {
            $responseJson = json_decode($responseJson,true);
        }
        if(!is_array($responseJson) || !isset($responseJson["error"])) {
            return false;
        }
        $error = $responseJson["error"];
        if(isset($error["code"]) && $error["code"] == "context_length_exceeded"
            && isset($error["param"]) && $error["param"] == "messages") {
            $keepLastN = $this->messages()->count() > 3 ? 3 : 1;
            $this->compress($keepLastN);
            return true;
        }
        return false;
    }

    /**
     * Drop the whole conversation.
     */
    public function clearMessages() {
        $this->messages = new Messages();
    }

    /**
     * Compress the conversation; delegates to Messages::compress() and returns its result.
     * NOTE(review): presumably $excludeLastN trailing messages and the first $startAt
     * messages are left untouched; confirm against Messages::compress().
     */
    public function compress($excludeLastN=2,$startAt=2) {
        return $this->messages->compress($excludeLastN,$startAt);
    }

    /**
     * Build a chat from an options array. The constructor already consumes
     * $options["messages"], so they are not added a second time here.
     *
     * @param array $options
     * @return OpenAIChat
     */
    public static function fromArray($options=[]) {
        return new OpenAIChat($options);
    }

    /**
     * Get (creating it lazily) the tokenizer for the current model.
     *
     * @param EncoderProvider|Encoder|false|null $encoderProvider
     *        - EncoderProvider: used to build the encoder for the current model;
     *        - false: discard the cached encoder and rebuild it from a new EncoderProvider;
     *        - any other non-null value: stored directly as the encoder;
     *        - null: keep the cached encoder.
     * @return Encoder
     */
    public function encoder($encoderProvider=null) {
        if($encoderProvider === false) {
            $this->encoder = null;
        } elseif($encoderProvider instanceof EncoderProvider) {
            $this->encoder = $encoderProvider->getForModel($this->model());
        } elseif($encoderProvider !== null) {
            $this->encoder = $encoderProvider;
        }
        if(!$this->encoder) {
            $provider = new EncoderProvider();
            $this->encoder = $provider->getForModel($this->model());
        }
        return $this->encoder;
    }

    /**
     * Tokenize $text with the current model's encoder.
     */
    public function encode($text) {
        return $this->encoder()->encode($text);
    }

    /**
     * Turn tokens back into text with the current model's encoder.
     * NOTE(review): despite its name, $text is presumably the token list returned by encode().
     */
    public function decode($text) {
        return $this->encoder()->decode($text);
    }
}