diff options
| -rw-r--r-- | include/administration.hpp | 18 | ||||
| -rw-r--r-- | include/file_templates.hpp | 16 | ||||
| -rw-r--r-- | src/administration.cpp | 18 | ||||
| -rw-r--r-- | src/administration_reader.cpp | 33 | ||||
| -rw-r--r-- | src/administration_writer.cpp | 33 | ||||
| -rw-r--r-- | src/importer.cpp | 8 | ||||
| -rw-r--r-- | src/providers/DeepSeek.cpp | 4 | ||||
| -rw-r--r-- | src/providers/openAI.cpp | 30 | ||||
| -rw-r--r-- | src/ui/ui_settings.cpp | 11 | ||||
| -rw-r--r-- | tests/administration_rw_tests.cpp | 6 |
10 files changed, 138 insertions, 39 deletions
diff --git a/include/administration.hpp b/include/administration.hpp index 5c39d59..2179301 100644 --- a/include/administration.hpp +++ b/include/administration.hpp @@ -44,6 +44,8 @@ #define MAX_LEN_PROJECT_REPORT_COSTCENTERS 50 #define MAX_BILLING_ITEMS 50 +#define MAX_AI_SERVICES AI_PROVIDER_END + #define ACTIVITY_MAX_PARAMS 3 #define ACTIVITY_USER "user" #define ACTIVITY_SYSTEM "system" @@ -394,7 +396,8 @@ typedef void (*project_changed_event)(project* project); typedef enum { AI_PROVIDER_OPENAI = 0, - //AI_PROVIDER_DEEPSEEK = 1, + AI_PROVIDER_GEMINI = 1, + AI_PROVIDER_PERPLEXITY = 2, AI_PROVIDER_END, } ai_provider; @@ -403,7 +406,7 @@ typedef struct { ai_provider provider; char model_name[MAX_LEN_SHORT_DESC]; - char api_key_public[MAX_LEN_API_KEY]; + char api_key_public[MAX_LEN_API_KEY]; // @TODO rename to api_key } ai_service; typedef enum @@ -449,7 +452,10 @@ typedef struct list_t invoices; // Service providers. - ai_service ai_service; + u32 ai_service_count; + ai_service all_ai_services[MAX_AI_SERVICES]; + ai_service ai_service; // @TODO rename to active_ai_service + email_service email_service; } ledger; @@ -518,12 +524,14 @@ namespace administration { char* get_currency_symbol_for_currency(char* code); char* get_default_currency(); time_t get_default_invoice_expire_duration(); - ai_service get_ai_service(); + ai_service get_active_ai_service(); + ai_service get_ai_service(ai_provider provider); email_service get_email_service(); void set_file_path(char* path); void set_next_id(s32 nr); void set_next_sequence_number(s32 nr); - void set_ai_service(ai_service provider); + void set_active_ai_service(ai_service provider); + void import_ai_service(ai_service provider); void set_email_service(email_service provider); void create_income_statement(income_statement* statement); bool company_info_is_valid(); diff --git a/include/file_templates.hpp b/include/file_templates.hpp index ec78cb7..b79b5bf 100644 --- a/include/file_templates.hpp +++ b/include/file_templates.hpp @@ -61,16 +61,22 @@ namespace file_template { " </Address>\n" "</Contact>"; + static const char* ai_service_template = + " <AIService>\n" + " <Provider>{{AI_SERVICE_PROVIDER}}</Provider>\n" + " <PublicKey>{{AI_SERVICE_KEY}}</PublicKey>\n" + " <Model>{{AI_SERVICE_MODEL}}</Model>\n" + " </AIService>\n"; + static const char* administration_save_template = "<Administration>\n" " <NextId>{{NEXT_ID}}</NextId>\n" " <NextSequenceNumber>{{NEXT_SEQUENCE_NUMBER}}</NextSequenceNumber>\n" " <ProgramVersion>{{PROGRAM_VERSION}}</ProgramVersion>\n" - " <AIService>\n" - " <Provider>{{AI_SERVICE_PROVIDER}}</Provider>\n" - " <PublicKey>{{AI_SERVICE_PUBLIC_KEY}}</PublicKey>\n" - " <Model>{{AI_SERVICE_MODEL}}</Model>\n" - " </AIService>\n" + " <AIServices>\n" + " <Provider>{{ACTIVE_AI_SERVICE_PROVIDER}}</Provider>\n" + " {{AI_SERVICE_LIST}}\n" + " </AIServices>\n" " <EmailService>\n" " <Provider>{{EMAIL_SERVICE_PROVIDER}}</Provider>\n" " <PublicKey>{{EMAIL_SERVICE_KEY}}</PublicKey>\n" diff --git a/src/administration.cpp b/src/administration.cpp index 182dabd..9b0144e 100644 --- a/src/administration.cpp +++ b/src/administration.cpp @@ -135,6 +135,7 @@ void administration_create() strops::copy(g_administration.path, "", sizeof(g_administration.path)); memops::zero(&g_administration.ai_service, sizeof(ai_service)); + memops::zero(&g_administration.all_ai_services, sizeof(g_administration.all_ai_services)); // Load all tax rates. for (s32 i = 0; i < country::get_count(); i++) @@ -232,11 +233,16 @@ void administration::create_default(char* save_file) // Other functions. // ======================= -ai_service administration::get_ai_service() +ai_service administration::get_active_ai_service() { return g_administration.ai_service; } +ai_service administration::get_ai_service(ai_provider provider) +{ + return g_administration.all_ai_services[(u32)provider]; +} + email_service administration::get_email_service() { return g_administration.email_service; @@ -248,9 +254,15 @@ void administration::set_email_service(email_service provider) if (administration_data_changed_event_callback) administration_data_changed_event_callback(); } -void administration::set_ai_service(ai_service provider) +void administration::import_ai_service(ai_service service) +{ + g_administration.all_ai_services[(u32)service.provider] = service; +} + +void administration::set_active_ai_service(ai_service service) { - g_administration.ai_service = provider; + g_administration.all_ai_services[(u32)service.provider] = service; + g_administration.ai_service = service; if (administration_data_changed_event_callback) administration_data_changed_event_callback(); } diff --git a/src/administration_reader.cpp b/src/administration_reader.cpp index a5f2449..0b5c233 100644 --- a/src/administration_reader.cpp +++ b/src/administration_reader.cpp @@ -460,11 +460,34 @@ bool administration_reader::import_administration_info(char* buffer, size_t buff administration::set_next_id(xml_get_s32(root, "NextId")); administration::set_next_sequence_number(xml_get_s32(root, "NextSequenceNumber")); - ai_service ai_service; - ai_service.provider = (ai_provider)xml_get_s32_x(root, "AIService", "Provider", 0); - xml_get_str_x(root, ai_service.api_key_public, MAX_LEN_API_KEY, "AIService", "PublicKey", 0); - xml_get_str_x(root, ai_service.model_name, MAX_LEN_SHORT_DESC, "AIService", "Model", 0); - administration::set_ai_service(ai_service); + { // Load AI services. + xml_node* ai_service_root = xml_easy_child(root, (uint8_t *)"AIServices", 0); + + ai_service active_ai_service; + active_ai_service.provider = (ai_provider)xml_get_s32_x(ai_service_root, "Provider", 0); + + size_t ai_service_count = xml_node_children(ai_service_root); + for (size_t x = 0; x < ai_service_count; x++) + { + xml_node* child = xml_node_child(ai_service_root, x); + + char* child_name = (char*)xml_easy_name(child); + if (strops::equals(child_name, "AIService")) + { + ai_service service; + service.provider = (ai_provider)xml_get_s32_x(child, "Provider", 0); + xml_get_str_x(child, service.api_key_public, MAX_LEN_API_KEY, "PublicKey", 0); + xml_get_str_x(child, service.model_name, MAX_LEN_SHORT_DESC, "Model", 0); + + administration::import_ai_service(service); + } + + memops::unalloc(child_name); + } + + active_ai_service = administration::get_ai_service(active_ai_service.provider); + administration::set_active_ai_service(active_ai_service); + } email_service email_service; email_service.provider = (email_provider)xml_get_s32_x(root, "EmailService", "Provider", 0); diff --git a/src/administration_writer.cpp b/src/administration_writer.cpp index da9ee77..c1637f4 100644 --- a/src/administration_writer.cpp +++ b/src/administration_writer.cpp @@ -921,10 +921,35 @@ bool administration_writer::save_administration_info_blocking() strops::replace_int32(file_content, buf_length, "{{NEXT_SEQUENCE_NUMBER}}", administration::get_next_sequence_number()); strops::replace(file_content, buf_length, "{{PROGRAM_VERSION}}", config::PROGRAM_VERSION); - ai_service ai_service = administration::get_ai_service(); - strops::replace_int32(file_content, buf_length, "{{AI_SERVICE_PROVIDER}}", (s32)ai_service.provider); - strops::replace(file_content, buf_length, "{{AI_SERVICE_PUBLIC_KEY}}", ai_service.api_key_public); - strops::replace(file_content, buf_length, "{{AI_SERVICE_MODEL}}", ai_service.model_name); + ai_service active_ai_service = administration::get_active_ai_service(); + strops::replace_int32(file_content, buf_length, "{{ACTIVE_AI_SERVICE_PROVIDER}}", (s32)active_ai_service.provider); + + { + u32 ai_service_list_buffer_size = (u32)(strops::length(file_template::ai_service_template) + 1000) * AI_PROVIDER_END; // @TODO + char* ai_service_list_buffer = (char*)memops::alloc(ai_service_list_buffer_size); + memops::zero(ai_service_list_buffer, ai_service_list_buffer_size); + u32 ai_service_list_buffer_cursor = 0; + + for (u32 i = 0; i < AI_PROVIDER_END; i++) + { + int buf_length = 0; + char* file_content = copy_template(file_template::ai_service_template, &buf_length); + + ai_service service = administration::get_ai_service((ai_provider)i); + strops::replace_int32(file_content, buf_length, "{{AI_SERVICE_PROVIDER}}", i); + strops::replace(file_content, buf_length, "{{AI_SERVICE_KEY}}", service.api_key_public); + strops::replace(file_content, buf_length, "{{AI_SERVICE_MODEL}}", service.model_name); + + u32 content_len = (u32)strops::length(file_content); + memops::copy(ai_service_list_buffer+ai_service_list_buffer_cursor, file_content, content_len); + + ai_service_list_buffer_cursor += content_len; + } + + strops::replace(file_content, buf_length, "{{AI_SERVICE_LIST}}", ai_service_list_buffer); + memops::unalloc(ai_service_list_buffer); + } + email_service email_service = administration::get_email_service(); strops::replace_int32(file_content, buf_length, "{{EMAIL_SERVICE_PROVIDER}}", (s32)email_service.provider); diff --git a/src/importer.cpp b/src/importer.cpp index 65726bc..0ae0d79 100644 --- a/src/importer.cpp +++ b/src/importer.cpp @@ -27,14 +27,16 @@ #include "administration_reader.hpp" extern importer::ai_provider_impl _chatgpt_api_provider; -extern importer::ai_provider_impl _deepseek_api_provider; +extern importer::ai_provider_impl _gemini_api_provider; +extern importer::ai_provider_impl _perplexity_api_provider; importer::ai_provider_impl importer::get_ai_provider_implementation(ai_provider provider) { switch(provider) { case AI_PROVIDER_OPENAI: return _chatgpt_api_provider; - //case AI_PROVIDER_DEEPSEEK: return _deepseek_api_provider; + case AI_PROVIDER_GEMINI: return _gemini_api_provider; + case AI_PROVIDER_PERPLEXITY: return _perplexity_api_provider; default: assert(0); break; } @@ -206,7 +208,7 @@ static int _ai_document_to_invoice_t(void *arg) { importer::invoice_request* request = (importer::invoice_request*)arg; char* file_path = request->file_path; - importer::ai_provider_impl impl = importer::get_ai_provider_implementation(administration::get_ai_service().provider); + importer::ai_provider_impl impl = importer::get_ai_provider_implementation(administration::get_active_ai_service().provider); request->status = importer::status::IMPORT_UPLOADING_FILE; diff --git a/src/providers/DeepSeek.cpp b/src/providers/DeepSeek.cpp index c34e299..8a5b42e 100644 --- a/src/providers/DeepSeek.cpp +++ b/src/providers/DeepSeek.cpp @@ -30,7 +30,7 @@ static bool _DeepSeek_query_with_file(const char* query, size_t query_length, ch (void)query_length; assert(query_buffer); - const char *api_key = administration::get_ai_service().api_key_public; + const char *api_key = administration::get_active_ai_service().api_key_public; httplib::SSLClient cli("api.deepseek.com"); //cli.enable_server_certificate_verification(false); @@ -46,7 +46,7 @@ static bool _DeepSeek_query_with_file(const char* query, size_t query_length, ch size_t body_size = file_size + QUERY_BUFFER_SIZE; char* body = (char*)memops::alloc(body_size); strops::format(body, body_size, - "{\"model\":\"%s\", \"messages\": [ { \"role\": \"user\", \"content\": \"%s\" } ] }", administration::get_ai_service().model_name, query_escaped); + "{\"model\":\"%s\", \"messages\": [ { \"role\": \"user\", \"content\": \"%s\" } ] }", administration::get_active_ai_service().model_name, query_escaped); httplib::Headers headers; headers.insert(std::make_pair("Authorization", std::string("Bearer ") + api_key)); diff --git a/src/providers/openAI.cpp b/src/providers/openAI.cpp index d1495dc..6c60541 100644 --- a/src/providers/openAI.cpp +++ b/src/providers/openAI.cpp @@ -26,7 +26,7 @@ static bool _openAI_batch_query_with_file(const char** queries, size_t query_count, char* file_id, invoice* buffer, importer::batch_query_response_handler response_handler) { - const char *api_key = administration::get_ai_service().api_key_public; + const char *api_key = administration::get_active_ai_service().api_key_public; httplib::SSLClient cli("api.openai.com", 443); thrd_t threads[20]; @@ -50,7 +50,7 @@ static bool _openAI_batch_query_with_file(const char** queries, size_t query_cou " }" "], " " \"text\": { \"format\": { \"type\": \"json_object\" } } " - "}", administration::get_ai_service().model_name, file_id, query_escaped); + "}", administration::get_active_ai_service().model_name, file_id, query_escaped); httplib::Headers headers; headers.insert(std::make_pair("Authorization", std::string("Bearer ") + api_key)); @@ -96,7 +96,7 @@ static bool _openAI_batch_query_with_file(const char** queries, size_t query_cou static bool _openAI_query_with_file(const char* query, size_t query_length, char* file_id, char** response) { - const char *api_key = administration::get_ai_service().api_key_public; + const char *api_key = administration::get_active_ai_service().api_key_public; httplib::SSLClient cli("api.openai.com", 443); //cli.enable_server_certificate_verification(false); @@ -108,7 +108,7 @@ static bool _openAI_query_with_file(const char* query, size_t query_length, char char* body = (char*)memops::alloc(body_size); strops::format(body, body_size, "{\"model\":\"%s\", \"input\": [ { \"role\": \"user\", \"content\": [ { \"type\": \"input_file\", \"file_id\": \"%s\" }, " - "{ \"type\": \"input_text\", \"text\": \"%s\" } ] } ] }", administration::get_ai_service().model_name, file_id, query_escaped); + "{ \"type\": \"input_text\", \"text\": \"%s\" } ] } ] }", administration::get_active_ai_service().model_name, file_id, query_escaped); httplib::Headers headers; headers.insert(std::make_pair("Authorization", std::string("Bearer ") + api_key)); @@ -135,7 +135,7 @@ static bool _openAI_query_with_file(const char* query, size_t query_length, char static bool _openAI_upload_file(char* file_path, char* file_id, size_t file_id_len) { - const char *api_key = administration::get_ai_service().api_key_public; + const char *api_key = administration::get_active_ai_service().api_key_public; const char *filename = strops::get_filename(file_path); FILE* orig_file = fopen(file_path, "rb"); @@ -237,7 +237,7 @@ static bool _openAI_upload_file(char* file_path, char* file_id, size_t file_id_l static bool _openAI_get_available_models(importer::model_list_request* buffer) { - const char *api_key = administration::get_ai_service().api_key_public; + const char *api_key = administration::get_active_ai_service().api_key_public; httplib::SSLClient cli("api.openai.com", 443); @@ -273,4 +273,22 @@ importer::ai_provider_impl _chatgpt_api_provider = { _openAI_query_with_file, _openAI_batch_query_with_file, _openAI_get_available_models, +}; + +importer::ai_provider_impl _gemini_api_provider = { + "Gemini", + "", + 0, + 0, + 0, + 0, +}; + +importer::ai_provider_impl _perplexity_api_provider = { + "Perplexity", + "", + 0, + 0, + 0, + 0, };
\ No newline at end of file diff --git a/src/ui/ui_settings.cpp b/src/ui/ui_settings.cpp index 0d3a210..c10ca26 100644 --- a/src/ui/ui_settings.cpp +++ b/src/ui/ui_settings.cpp @@ -60,7 +60,7 @@ void ui::setup_settings() cost_centers = (cost_center*)memops::alloc(cost_center_count * sizeof(cost_center)); administration::cost_center_get_all(cost_centers); - new_ai_service = administration::get_ai_service(); + new_ai_service = administration::get_active_ai_service(); new_email_service = administration::get_email_service(); } } @@ -275,9 +275,14 @@ static void draw_ai_service_ui() { bool is_selected = n == (uint32_t)new_ai_service.provider; if (ImGui::Selectable(ai_service_names[n], is_selected)) { + + ai_service service = administration::get_ai_service((ai_provider)n); new_ai_service.provider = (ai_provider)n; + strops::copy(new_ai_service.model_name, service.model_name, MAX_LEN_SHORT_DESC); + strops::copy(new_ai_service.api_key_public, service.api_key_public, MAX_LEN_API_KEY); + model_request = 0; - set_model_on_load = true; + set_model_on_load = strops::length(new_ai_service.model_name) == 0; } } ImGui::EndCombo(); @@ -341,7 +346,7 @@ static void draw_ai_service_ui() strops::format(id, 100, "%s##ai", locale::get("form.save")); if (ImGui::Button(id, true)) { administration_writer::set_write_completed_event_callback(0); - administration::set_ai_service(new_ai_service); + administration::set_active_ai_service(new_ai_service); } } } diff --git a/tests/administration_rw_tests.cpp b/tests/administration_rw_tests.cpp index a33173f..7e80844 100644 --- a/tests/administration_rw_tests.cpp +++ b/tests/administration_rw_tests.cpp @@ -148,9 +148,9 @@ TEST _administration_rw_info(void) ss.provider = AI_PROVIDER_OPENAI; strops::copy(ss.api_key_public, "123", sizeof(ss.api_key_public)); strops::copy(ss.model_name, "321", sizeof(ss.model_name)); - administration::set_ai_service(ss); + administration::set_active_ai_service(ss); - ais = administration::get_ai_service(); + ais = administration::get_active_ai_service(); } administration_reader::open_existing(test_file_path); @@ -158,7 +158,7 @@ TEST _administration_rw_info(void) ASSERT_EQ(next_id, administration::get_next_id()); ASSERT_EQ(next_sequence_number, administration::get_next_sequence_number()); - ai_service rs = administration::get_ai_service(); + ai_service rs = administration::get_active_ai_service(); ASSERT_MEM_EQ(&ais, &rs, sizeof(ai_service)); } |
