summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/administration_reader.cpp1
-rw-r--r--src/administration_writer.cpp1
-rw-r--r--src/ai_providers/DeepSeek.cpp4
-rw-r--r--src/ai_providers/openAI.cpp37
-rw-r--r--src/importer.cpp44
-rw-r--r--src/locales/en.cpp1
-rw-r--r--src/strops.cpp3
-rw-r--r--src/ui/ui_settings.cpp55
8 files changed, 140 insertions, 6 deletions
diff --git a/src/administration_reader.cpp b/src/administration_reader.cpp
index 73fc788..c22638b 100644
--- a/src/administration_reader.cpp
+++ b/src/administration_reader.cpp
@@ -407,6 +407,7 @@ bool administration_reader::import_administration_info(char* buffer, size_t buff
ai_service ai_service;
ai_service.provider = (ai_provider)xml_get_s32_x(root, "AIService", "Provider", 0);
xml_get_str_x(root, ai_service.api_key_public, MAX_LEN_API_KEY, "AIService", "PublicKey", 0);
+ xml_get_str_x(root, ai_service.model_name, MAX_LEN_SHORT_DESC, "AIService", "Model", 0);
administration::set_ai_service(ai_service);
logger::info("Loaded administration info in %.3fms. next_id=%d next_sequence_number=%d",
diff --git a/src/administration_writer.cpp b/src/administration_writer.cpp
index 81c90fa..5010dd8 100644
--- a/src/administration_writer.cpp
+++ b/src/administration_writer.cpp
@@ -791,6 +791,7 @@ bool administration_writer::save_all_administration_info_blocking()
ai_service ai_service = administration::get_ai_service();
strops::replace_int32(file_content, buf_length, "{{AI_SERVICE_PROVIDER}}", (s32)ai_service.provider);
strops::replace(file_content, buf_length, "{{AI_SERVICE_PUBLIC_KEY}}", ai_service.api_key_public);
+ strops::replace(file_content, buf_length, "{{AI_SERVICE_MODEL}}", ai_service.model_name);
//// Write to Disk.
int final_length = (int)strlen(file_content);
diff --git a/src/ai_providers/DeepSeek.cpp b/src/ai_providers/DeepSeek.cpp
index a1857e9..2bc4dde 100644
--- a/src/ai_providers/DeepSeek.cpp
+++ b/src/ai_providers/DeepSeek.cpp
@@ -46,7 +46,7 @@ static bool _DeepSeek_query_with_file(char* query, size_t query_length, char* fi
size_t body_size = file_size + QUERY_BUFFER_SIZE;
char* body = (char*)memops::alloc(body_size);
strops::format(body, body_size,
- "{\"model\":\"deepseek-reasoner\", \"messages\": [ { \"role\": \"user\", \"content\": \"%s\" } ] }", query_escaped);
+ "{\"model\":\"%s\", \"messages\": [ { \"role\": \"user\", \"content\": \"%s\" } ] }", administration::get_ai_service().model_name, query_escaped);
httplib::Headers headers;
headers.insert(std::make_pair("Authorization", std::string("Bearer ") + api_key));
@@ -113,6 +113,8 @@ static bool _DeepSeek_upload_file(char* file_path, char* file_id, size_t file_id
importer::ai_provider_impl _deepseek_api_provider = {
"DeekSeek",
+ "deepseek-reasoner",
_DeepSeek_upload_file,
_DeepSeek_query_with_file,
+ 0,
}; \ No newline at end of file
diff --git a/src/ai_providers/openAI.cpp b/src/ai_providers/openAI.cpp
index b55f191..fba050c 100644
--- a/src/ai_providers/openAI.cpp
+++ b/src/ai_providers/openAI.cpp
@@ -35,8 +35,8 @@ static bool _openAI_query_with_file(char* query, size_t query_length, char* file
size_t body_size = query_length + 200;
char* body = (char*)memops::alloc(body_size);
strops::format(body, body_size,
- "{\"model\":\"gpt-5-nano\", \"input\": [ { \"role\": \"user\", \"content\": [ { \"type\": \"input_file\", \"file_id\": \"%s\" }, "
- "{ \"type\": \"input_text\", \"text\": \"%s\" } ] } ] }", file_id, query_escaped);
+ "{\"model\":\"%s\", \"input\": [ { \"role\": \"user\", \"content\": [ { \"type\": \"input_file\", \"file_id\": \"%s\" }, "
+ "{ \"type\": \"input_text\", \"text\": \"%s\" } ] } ] }", administration::get_ai_service().model_name, file_id, query_escaped);
httplib::Headers headers;
headers.insert(std::make_pair("Authorization", std::string("Bearer ") + api_key));
@@ -168,8 +168,41 @@ static bool _openAI_upload_file(char* file_path, char* file_id, size_t file_id_l
return 1;
}
+static bool _openAI_get_available_models(importer::model_list_request* buffer)
+{
+ const char *api_key = administration::get_ai_service().api_key_public;
+
+ httplib::SSLClient cli("api.openai.com", 443);
+
+ httplib::Headers headers;
+ headers.insert(std::make_pair("Authorization", std::string("Bearer ") + api_key));
+
+ httplib::Result res = cli.Get("/v1/models", headers);
+ if (!res || res->status != 200) {
+ logger::error("ERROR Failed to get models list.");
+ logger::error(res->body.c_str());
+ return 0;
+ }
+
+ char* completion_body_response = (char*)res->body.c_str();
+
+ u32 count = 0;
+ char model_name[MAX_LEN_SHORT_DESC];
+
+ while(1) {
+ if (!strops::get_json_value(completion_body_response, "id", model_name, MAX_LEN_SHORT_DESC, count++)) break;
+ if (count == MAX_MODEL_LIST_RESULT_COUNT) break;
+
+ strops::copy(buffer->result[buffer->result_count++], model_name, MAX_LEN_SHORT_DESC);
+ }
+
+ return 1;
+}
+
importer::ai_provider_impl _chatgpt_api_provider = {
"OpenAI",
+ "gpt-5-nano",
_openAI_upload_file,
_openAI_query_with_file,
+ _openAI_get_available_models,
}; \ No newline at end of file
diff --git a/src/importer.cpp b/src/importer.cpp
index 23960b1..e40de5b 100644
--- a/src/importer.cpp
+++ b/src/importer.cpp
@@ -160,6 +160,50 @@ importer::invoice_request* importer::ai_document_to_invoice(char* file_path)
return result;
}
+static int _ai_get_available_models_t(void* arg)
+{
+ importer::model_list_request* request = (importer::model_list_request*)arg;
+ importer::ai_provider_impl impl = importer::get_ai_provider_implementation(request->service);
+
+ if (!impl.get_available_models) {
+ request->status = importer::status::IMPORT_DONE;
+ request->error = I_ERR_UNIMPLEMENTED;
+ return 0;
+ }
+
+ request->status = importer::status::IMPORT_WAITING_FOR_RESPONSE;
+
+ if (!impl.get_available_models(request)) {
+ request->status = importer::status::IMPORT_DONE;
+ request->error = I_ERR_FAILED_QUERY;
+ return 0;
+ }
+
+ request->status = importer::status::IMPORT_DONE;
+
+ return 0;
+}
+
+importer::model_list_request* importer::ai_get_available_models(ai_provider service)
+{
+ importer::model_list_request* result = (importer::model_list_request*)memops::alloc(sizeof(importer::model_list_request));
+ result->started_at = time(NULL);
+ result->error = I_ERR_SUCCESS;
+ result->status = importer::status::IMPORT_STARTING;
+ result->result_count = 0;
+ result->service = service;
+ memset(result->result, 0, sizeof(result->result));
+
+ thrd_t thr;
+ if (thrd_create(&thr, _ai_get_available_models_t, result) != thrd_success) {
+ result->status = importer::status::IMPORT_DONE;
+ result->error = I_ERR_FAILED_QUERY;
+ return 0;
+ }
+
+ return result;
+}
+
const char* importer::status_to_string(importer::status status)
{
switch(status)
diff --git a/src/locales/en.cpp b/src/locales/en.cpp
index 20d482e..4abcab8 100644
--- a/src/locales/en.cpp
+++ b/src/locales/en.cpp
@@ -151,6 +151,7 @@ locale_entry en_locales[] = {
{"settings.costcenters.table.description", "Description"},
{"settings.services.ai_service", "AI Service"},
{"settings.services.ai_service.provider", "Provider"},
+ {"settings.services.ai_service.model", "Model"},
{"settings.services.ai_service.privkey", "Public key"},
{"settings.services.ai_service.pubkey", "Private key"},
diff --git a/src/strops.cpp b/src/strops.cpp
index d47ec2e..efa91e6 100644
--- a/src/strops.cpp
+++ b/src/strops.cpp
@@ -143,7 +143,8 @@ namespace strops {
const char *pos = strstr(json, pattern);
while(skip > 0) {
pos = strstr(pos+1, pattern);
- skip--;
+ skip--;
+ if (!pos) return 0;
}
if (!pos) return NULL;
pos = strchr(pos, ':');
diff --git a/src/ui/ui_settings.cpp b/src/ui/ui_settings.cpp
index dd59323..cf27b49 100644
--- a/src/ui/ui_settings.cpp
+++ b/src/ui/ui_settings.cpp
@@ -335,6 +335,9 @@ static void draw_cost_centers()
static void draw_services()
{
+ static importer::model_list_request* model_request = 0;
+ static bool set_model_on_load = false;
+
// AI service
if (ImGui::CollapsingHeader(locale::get("settings.services.ai_service")))
{
@@ -350,6 +353,8 @@ static void draw_services()
bool is_selected = n == (uint32_t)new_service.provider;
if (ImGui::Selectable(ai_service_names[n], is_selected)) {
new_service.provider = (ai_provider)n;
+ model_request = 0;
+ set_model_on_load = true;
}
}
ImGui::EndCombo();
@@ -357,10 +362,56 @@ static void draw_services()
ImGui::InputTextWithHint(locale::get("settings.services.ai_service.pubkey"), locale::get("settings.services.ai_service.pubkey"),
new_service.api_key_public, sizeof(new_service.api_key_public));
+
+ if (!model_request) {
+ model_request = importer::ai_get_available_models(new_service.provider);
+ }
+
+ // Default to first result in model list, or hardcoded default model.
+ if (model_request->status == importer::status::IMPORT_DONE) {
+ if (set_model_on_load) {
+ set_model_on_load = false;
+
+ if (model_request->result_count > 0) {
+ strops::copy(new_service.model_name, model_request->result[0], sizeof(new_service.model_name));
+ }
+ else {
+ strops::copy(new_service.model_name, importer::get_ai_provider_implementation(new_service.provider).default_model, sizeof(new_service.model_name));
+ }
+ }
+ }
+
+ if (model_request->status == importer::status::IMPORT_DONE && model_request->error == I_ERR_SUCCESS) {
+ if (ImGui::BeginCombo(locale::get("settings.services.ai_service.model"), new_service.model_name))
+ {
+ for (u32 n = 0; n < model_request->result_count; n++)
+ {
+ bool is_selected = strops::equals(new_service.model_name, model_request->result[n]);
+ if (ImGui::Selectable(model_request->result[n], is_selected)) {
+ strops::copy(new_service.model_name, model_request->result[n], sizeof(new_service.model_name));
+ }
+ }
+ ImGui::EndCombo();
+ }
+ }
+ else {
+ ImGui::BeginDisabled();
+ if (ImGui::BeginCombo(locale::get("settings.services.ai_service.model"), new_service.model_name))
+ {
+ ImGui::EndCombo();
+ }
+ if (model_request->status != importer::status::IMPORT_DONE) {
+ ImGui::SameLine();
+
+ // TODO replace with LoadingIndicatorCircle
+ ImGui::Text("%c", "|/-\\"[(int)(ImGui::GetTime() / 0.05f) & 3]);
+ }
+ ImGui::EndDisabled();
+ }
if (ImGui::Button(locale::get("form.save"))) {
- administration::set_ai_service(new_service);
- }
+ administration::set_ai_service(new_service);
+ }
}
}