Compare commits

...

10 Commits

Author SHA1 Message Date
dranik
163cb11319 add send/post test dart 2026-06-26 15:44:15 +03:00
dranik
71ca24d56b remove comment & format code 2026-06-26 11:38:44 +03:00
dranik
1d55372da4 fix main thread & rename files 2026-06-26 11:38:44 +03:00
dranik
f85b810efe formate code 2026-06-26 08:55:18 +03:00
dranik
70222943dd integrate project (6) 2026-06-25 18:36:51 +03:00
dranik
1fbd49f55b fix 1106 & add support add_subdirectory 2026-06-25 15:28:07 +03:00
dranik
ff8ebbfdc9 C-ABI, shared agw_capi (agw_* up), C-smoke & Dart-smoke 2026-06-25 12:55:40 +03:00
dranik
7736fa1946 async & cancel (4) 2026-06-25 09:42:13 +03:00
dranik
7e6e70d1d3 add failover (3) 2026-06-25 08:33:17 +03:00
dranik
c859128cb2 init project agw-sdk (1) & add public include (2) 2026-06-24 18:04:52 +03:00
82 changed files with 29247 additions and 794 deletions

14
agw-sdk/.gitignore vendored Normal file
View File

@@ -0,0 +1,14 @@
# Локальные сборки
build/
build-*/
cmake-build-*/
# Conan
CMakeUserPresets.json
conan.lock
# Примеры
examples/dart_smoke/.dart_tool/
examples/dart_smoke/pubspec.lock
examples/c_smoke/smoke
test_package/

133
agw-sdk/CMakeLists.txt Normal file
View File

@@ -0,0 +1,133 @@
cmake_minimum_required(VERSION 3.21)
project(agw LANGUAGES CXX VERSION 0.1.0)
# --- стандарт ---------------------------------------------------------------
# floor C++20 (см. решения в docs/plans/gateway-sdk/README.md). C++23 — opt-in позже.
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "" FORCE)
endif()
# --- опции ------------------------------------------------------------------
# Когда SDK подключают через add_subdirectory (отладочная сборка в клиенте), по умолчанию НЕ строим
# тесты и НЕ строим shared C-ABI, чтобы не пачкать родительскую сборку. Standalone — всё включено.
if(PROJECT_IS_TOP_LEVEL)
set(_agw_default_aux ON)
else()
set(_agw_default_aux OFF)
endif()
option(AGW_BUILD_TESTS "Build agw-sdk tests" ${_agw_default_aux})
# Режимы зависимостей (Фаза 5): shared-deps — общий OpenSSL из Conan; vendored — бандл.
# На Фазе 1 определяем only-флаг, реальный механизм бандла появится в Фазе 5.
set(AGW_DEPS_MODE "shared-deps" CACHE STRING "Dependency mode: shared-deps | vendored")
# --- зависимости ------------------------------------------------------------
find_package(OpenSSL REQUIRED)
find_package(Threads REQUIRED)
# libcurl (транспорт по умолчанию). Если не найден — SDK собирается без него, а клиент обязан
# передать свой IHttpClient через Config (см. default_client_fallback.cpp).
find_package(CURL QUIET)
# nlohmann/json: предпочтительно из Conan (find_package), иначе — вендоренный single-header
# (он лежит в tests/third_party и нужен только для локальной сборки без Conan).
find_package(nlohmann_json QUIET)
if(NOT nlohmann_json_FOUND)
add_library(agw_nlohmann_fallback INTERFACE)
target_include_directories(agw_nlohmann_fallback INTERFACE
${CMAKE_CURRENT_SOURCE_DIR}/tests/third_party)
add_library(nlohmann_json::nlohmann_json ALIAS agw_nlohmann_fallback)
message(STATUS "agw: using vendored nlohmann/json fallback (tests/third_party)")
endif()
# --- библиотека -------------------------------------------------------------
# Один раз компилируем объекты, из них собираем static (agw) и shared C-ABI (agw_capi).
add_library(agw_obj OBJECT
src/gateway_controller.cpp
src/c_abi.cpp
src/crypto/rng.cpp
src/crypto/aes.cpp
src/crypto/rsa.cpp
src/crypto/hash.cpp
src/http/curl_client.cpp
src/http/default_client_fallback.cpp
src/protocol/request_builder.cpp
src/protocol/response.cpp
src/protocol/error_mapping.cpp
src/failover/bypass_policy.cpp
src/failover/proxy_list.cpp
src/failover/proxy_picker.cpp
src/util/base64.cpp
src/util/uuid.cpp
src/util/json.cpp
src/util/url.cpp
src/util/thread_pool.cpp
)
target_include_directories(agw_obj
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src
)
# PUBLIC, чтобы зависимости и include-пути дотекли до static/shared (и до тестов).
target_link_libraries(agw_obj
PUBLIC OpenSSL::Crypto nlohmann_json::nlohmann_json Threads::Threads
)
# Скрываем всё по умолчанию: наружу торчат только agw_* (AGW_API = visibility default).
set_target_properties(agw_obj PROPERTIES
CXX_VISIBILITY_PRESET hidden
VISIBILITY_INLINES_HIDDEN ON
POSITION_INDEPENDENT_CODE ON
)
if(CURL_FOUND)
target_compile_definitions(agw_obj PRIVATE AGW_HAVE_CURL)
target_link_libraries(agw_obj PUBLIC CURL::libcurl)
message(STATUS "agw: libcurl found — default HTTP client enabled")
else()
message(STATUS "agw: libcurl NOT found — default HTTP client disabled (inject IHttpClient via Config)")
endif()
# Статическая библиотека (C++ API) — для наших приложений / тестов.
add_library(agw STATIC)
target_link_libraries(agw PUBLIC agw_obj)
add_library(agw::agw ALIAS agw)
# Shared C-ABI библиотека (для dart:ffi и сторонних). Экспортирует только agw_*.
option(AGW_BUILD_CAPI_SHARED "Build shared C-ABI library (agw_capi)" ${_agw_default_aux})
if(AGW_BUILD_CAPI_SHARED)
add_library(agw_capi SHARED $<TARGET_OBJECTS:agw_obj>)
target_link_libraries(agw_capi PRIVATE agw_obj)
target_compile_definitions(agw_capi PRIVATE AGW_BUILDING_SHARED)
set_target_properties(agw_capi PROPERTIES OUTPUT_NAME agw_capi)
message(STATUS "agw: deps mode = ${AGW_DEPS_MODE} (vendored — статическая линковка/скрытие символов, через Conan)")
endif()
set_target_properties(agw PROPERTIES
CXX_VISIBILITY_PRESET hidden
VISIBILITY_INLINES_HIDDEN ON
POSITION_INDEPENDENT_CODE ON
)
# --- установка (только для Conan-пакета / standalone; не при add_subdirectory) ---------------
if(PROJECT_IS_TOP_LEVEL)
include(GNUInstallDirs)
install(TARGETS agw ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
if(AGW_BUILD_CAPI_SHARED)
install(TARGETS agw_capi
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/agw
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
endif()
# --- тесты ------------------------------------------------------------------
if(AGW_BUILD_TESTS)
enable_testing()
add_subdirectory(tests)
endif()

110
agw-sdk/README.md Normal file
View File

@@ -0,0 +1,110 @@
# agw-sdk
Qt-free C++20 транспорт к API-шлюзу Amnezia (вынос `GatewayController`). Узкая поверхность —
`post` (sync/async) поверх крипты, выбора эндпоинта и обхода блокировок. Протокол воспроизводится
байт-в-байт.
План и решения: [../docs/plans/gateway-sdk/](../docs/plans/gateway-sdk/) — начни с
`agw-sdk-tier1-impl-plan.md` и `README.md` (таблица решений).
## Статус
Тир 1, в работе по фазам:
- [x] **Фаза 1** — каркас + крипта на OpenSSL EVP (AES-256-CBC, RSA-PKCS1 v1.5, SHA-512), base64
(std + url), UUID v4, Qt-Indented JSON-сериализатор, golden-тесты крипты.
- [x] **Фаза 2**`IHttpClient`(libcurl) + `Config`/`GatewayController`/`executePost` + sync `post`;
`request_builder`/`response`/`error_mapping`; интеграционный тест через in-process mock-шлюз
(полный round-trip: SDK шифрует → «сервер» расшифровывает → шифрует ответ → SDK расшифровывает).
- [x] **Фаза 3** — failover: `bypass_policy` (`shouldBypassProxy` дословно), `proxy_list`
(S3-пути + prod-расшифровка через `SHA-512(pubkey)`), `proxy_picker` (health-check `lmbd-health`),
встройка в `executePost` с кешем рабочего прокси на инстансе (под мьютексом). Интеграционный тест:
прямой ответ подозрителен → S3 → health → прокси → успех; повторный запрос идёт сразу на кеш.
- [x] **Фаза 4**`util::ThreadPool` (drain в деструкторе), `postAsync`(коллбэк на потоке пула)/
`postFuture`(`std::future`) поверх `executePost`, `CancellationToken` (проверки между шагами
failover + прерывание трансфера через progress-коллбэк curl → `ErrorCode::Cancelled`). Кеш прокси
под мьютексом, пул — последний член Impl (рушится первым, дожидаясь задач). TSan чисто; ASan+UBSan
10/10.
- [x] **Фаза 5** — C-ABI (`include/agw/c_abi.h` + `src/c_abi.cpp`, `agw_*`: создание/уничтожение,
sync/async `post`, токен отмены, освобождение результата; на границе только C-типы). Сборка:
object-библиотека → static `agw` + shared `agw_capi` (экспортирует **только** `agw_*`, остальное
скрыто). C-smoke (чистый C) и Dart-smoke (`dart:ffi`) проходят. `conan create` (shared-deps)
зелёный — пакет с `libagw.a` + `libagw_capi.dylib` + заголовками. Режим `vendored` (статические
зависимости) задан в conanfile (`-o deps_mode=vendored`).
- [~] **Фаза 6** — интеграция в Qt-клиент. Готово: `GatewayController` переписан тонким адаптером
над `agw::GatewayController` (сигнатуры один в один, байт-паритет payload, персистентный клиент на
окружение, `onBeforeRequest` = iOS inet + desktop kill-switch, async через `QPromise`+маршалинг);
проводка сборки (корневой `conanfile` requires `agw-sdk/0.1.0`, `client/cmake/3rdparty.cmake`
линкует `agw::agw`). Осталось (вне этого окружения): Qt-сборка под все платформы, перевод
синхронных вызовов (`subscription`/`servicesCatalog` `executeRequest`) на рабочий поток, регрессия
против dev/prod. См. `docs/plans/gateway-sdk/agw-sdk-tier1-phase6-integration.md`.
## Раскладка
```
include/agw/ публичные заголовки (types, config, client, http, cancellation, c_abi)
src/crypto/ AES, RSA, SHA-512, RNG
src/util/ base64, uuid, json (Qt-Indented), url, thread_pool
src/protocol/ имена полей API, request_builder, response, error_mapping
src/failover/ bypass_policy, proxy_list, proxy_picker
src/http/ curl_client (+ fallback)
src/c_abi.cpp C-ABI обёртка
tests/ unit + golden + integration (+ вендоренный nlohmann для офлайн-сборки)
examples/ c_smoke (чистый C), dart_smoke (dart:ffi)
```
## C-ABI и потребление из Dart/C
Публичный C-заголовок — `include/agw/c_abi.h`. Shared-библиотека `libagw_capi.*` экспортирует только
`agw_*`. Примеры:
```sh
# чистый C
cc -std=c11 -Iinclude examples/c_smoke/smoke.c -Lbuild-local -lagw_capi -o /tmp/agw_smoke
DYLD_LIBRARY_PATH=build-local /tmp/agw_smoke # → код 1105, OK
# Dart (dart:ffi)
cd examples/dart_smoke && dart pub get && dart run # → код 1105, OK
```
## Локальная сборка и тесты (без Conan)
Нужны CMake ≥ 3.21 и OpenSSL 3. nlohmann/json берётся из вендоренного single-header
(`tests/third_party`), если Conan-пакет не найден.
```sh
cmake -S . -B build-local -DOPENSSL_ROOT_DIR=$(brew --prefix openssl@3)
cmake --build build-local -j
ctest --test-dir build-local --output-on-failure
```
Санитайзеры (macOS): TSan — на конкурентных тестах; ASan+UBSan — `detect_leaks=0` (LSan на Darwin
не поддержан):
```sh
cmake -S . -B build-asan -DOPENSSL_ROOT_DIR=$(brew --prefix openssl@3) \
-DCMAKE_CXX_FLAGS="-fsanitize=address,undefined -g -O1" \
-DCMAKE_EXE_LINKER_FLAGS="-fsanitize=address,undefined"
cmake --build build-asan -j
ASAN_OPTIONS=detect_leaks=0 ctest --test-dir build-asan --output-on-failure
```
## Сборка через Conan (как в проекте)
```sh
conan create . -o build_tests=True
```
Зависимости: `openssl/3.6.2`, `nlohmann_json/3.11.3` (как в корневом `conanfile.py`); `libcurl`
с Фазы 2.
## Заметки по паритету
Крипта сверена с `client/3rd/QSimpleCrypto` и `gatewayController.cpp`. Ключевое:
- AES-256-CBC, ключ 32 байта, IV генерится 32 — CBC берёт первые 16; salt (8 байт) в локальном AES
не участвует, уходит только в `key_payload`.
- RSA PKCS#1 v1.5 — паддинг рандомный, поэтому `key_payload` **не** воспроизводим байт-в-байт;
golden проверяет его round-trip, а `api_payload` (AES) — точные байты.
- JSON собирается в формате `QJsonDocument::toJson(Indented)`: отступ 4 пробела, завершающий `\n`,
**отсортированные ключи** (это даёт `aes_iv` раньше `aes_key`).

65
agw-sdk/conanfile.py Normal file
View File

@@ -0,0 +1,65 @@
from conan import ConanFile
from conan.tools.cmake import CMake, CMakeToolchain, CMakeDeps, cmake_layout
class AgwSdkConan(ConanFile):
name = "agw-sdk"
version = "0.1.0"
license = "TBD"
description = "AGW SDK — Qt-free C++ transport to the Amnezia API gateway (Tier 1)"
settings = "os", "compiler", "build_type", "arch"
# shared-deps: линкуем общий OpenSSL/curl/nlohmann из Conan (наши приложения).
# vendored: бандлим зависимости статически + скрытие символов (сторонние/standalone).
options = {
"deps_mode": ["shared-deps", "vendored"],
"build_tests": [True, False],
"build_capi_shared": [True, False],
}
default_options = {
"deps_mode": "shared-deps",
"build_tests": False,
"build_capi_shared": True,
}
exports_sources = "CMakeLists.txt", "include/*", "src/*", "tests/*"
def requirements(self):
# Версия OpenSSL совпадает с приложением (корневой conanfile.py) — без второго OpenSSL.
self.requires("openssl/3.6.2")
self.requires("libcurl/8.10.1")
self.requires("nlohmann_json/3.11.3")
def configure(self):
# vendored: тянем статические зависимости, чтобы забандлить их в библиотеку.
if self.options.deps_mode == "vendored":
self.options["openssl"].shared = False
self.options["libcurl"].shared = False
def layout(self):
cmake_layout(self)
def generate(self):
deps = CMakeDeps(self)
deps.generate()
tc = CMakeToolchain(self)
tc.variables["AGW_DEPS_MODE"] = str(self.options.deps_mode)
tc.variables["AGW_BUILD_TESTS"] = bool(self.options.build_tests)
tc.variables["AGW_BUILD_CAPI_SHARED"] = bool(self.options.build_capi_shared)
tc.generate()
def build(self):
cmake = CMake(self)
cmake.configure()
cmake.build()
def package(self):
cmake = CMake(self)
cmake.install()
def package_info(self):
self.cpp_info.libs = ["agw"]
self.cpp_info.includedirs = ["include"]
# Потребитель подключает: find_package(agw-sdk) + target agw::agw
self.cpp_info.set_property("cmake_file_name", "agw-sdk")
self.cpp_info.set_property("cmake_target_name", "agw::agw")

View File

@@ -0,0 +1,39 @@
/*
* Чистый C-потребитель C-ABI: доказывает, что agw_* линкуется и работает из C без C++/Qt.
* Детерминированный путь без сети: невалидный публичный ключ → ApiMissingAgwPublicKey (1105).
*
* Сборка (пример, macOS):
* cc -std=c11 -I ../../include smoke.c -L ../../build-local -lagw_capi -o smoke
* DYLD_LIBRARY_PATH=../../build-local ./smoke
*/
#include <stdio.h>
#include <string.h>
#include "agw/c_abi.h"
int main(void)
{
agw_config cfg;
memset(&cfg, 0, sizeof(cfg));
cfg.gateway_endpoint = "gw.example.test";
cfg.agw_public_key_pem = "not a real pem key"; /* → 1105 без обращения к сети */
cfg.request_timeout_msecs = 5000;
agw_client *client = agw_client_create(&cfg);
if (client == NULL) {
printf("FAIL: agw_client_create returned NULL\n");
return 1;
}
agw_response r = agw_client_post(client, "https://%1/api/v1/test", "{\"x\":1}", "", "", NULL);
printf("post error code = %d\n", r.error);
int ok = (r.error == 1105); /* ApiMissingAgwPublicKey */
agw_response_free(&r);
agw_client_destroy(client);
printf(ok ? "OK\n" : "FAIL\n");
return ok ? 0 : 1;
}

View File

@@ -0,0 +1,235 @@
// Dart-демо C-ABI agw-sdk через dart:ffi.
//
// Показывает поток запроса: подключает лог-хук SDK и onBeforeRequest, поэтому печатаются строки
// [agw] (post START -> direct request url -> direct response -> failover -> post DONE) — видно,
// что запрос ушёл и ответ пришёл. Делает синхронный post, а при AGW_ASYNC=1 — ещё и асинхронный
// (agw_client_post_async + NativeCallable.listener: коллбэк прилетает с потока пула SDK).
//
// Конфиг через переменные окружения (все опциональны):
// AGW_GATEWAY хост шлюза с "%1"-подстановкой, напр. "http://gw.dev.amzsvc.com:80/"
// AGW_PUBKEY_FILE путь к PEM ПУБЛИЧНОГО ключа шлюза (по умолчанию — тестовый фикстур)
// AGW_S3_PRIMARY список S3-адресов через запятую (failover)
// AGW_DEV "1" → dev-режим (S3-список открытым текстом)
// AGW_ENDPOINT шаблон пути, по умолчанию "%1v1/services"
// AGW_PAYLOAD JSON тела запроса
// AGW_ASYNC "1" → дополнительно прогнать асинхронный вызов
// AGW_CAPI_LIB путь к libagw_capi.* (иначе ../../build-local)
import 'dart:async';
import 'dart:ffi';
import 'dart:io';
import 'package:ffi/ffi.dart';
final class AgwConfig extends Struct {
external Pointer<Utf8> gatewayEndpoint;
external Pointer<Utf8> agwPublicKeyPem;
external Pointer<Pointer<Utf8>> s3Primary;
@Size()
external int s3PrimaryCount;
external Pointer<Pointer<Utf8>> s3Fallback;
@Size()
external int s3FallbackCount;
@Int32()
external int isDevEnvironment;
@Int32()
external int requestTimeoutMsecs;
@Int32()
external int proxyHealthTimeoutMsecs;
@Int32()
external int proxyStorageTimeoutMsecs;
@Int32()
external int threadPoolSize;
external Pointer<Void> onBeforeRequest;
external Pointer<Void> onBeforeRequestUserData;
external Pointer<Void> log;
external Pointer<Void> logUserData;
}
final class AgwResponse extends Struct {
@Int32()
external int error;
external Pointer<Utf8> body;
@Size()
external int bodyLen;
}
typedef _CreateC = Pointer<Void> Function(Pointer<AgwConfig>);
typedef _PostC = AgwResponse Function(Pointer<Void>, Pointer<Utf8>, Pointer<Utf8>,
Pointer<Utf8>, Pointer<Utf8>, Pointer<Void>);
typedef _PostAsyncC = Void Function(Pointer<Void>, Pointer<Utf8>, Pointer<Utf8>,
Pointer<Utf8>, Pointer<Utf8>, Pointer<Void>, Pointer<Void>, Pointer<Void>);
typedef _PostAsyncDart = void Function(Pointer<Void>, Pointer<Utf8>, Pointer<Utf8>,
Pointer<Utf8>, Pointer<Utf8>, Pointer<Void>, Pointer<Void>, Pointer<Void>);
typedef _FreeC = Void Function(Pointer<AgwResponse>);
typedef _FreeDart = void Function(Pointer<AgwResponse>);
typedef _DestroyC = Void Function(Pointer<Void>);
typedef _DestroyDart = void Function(Pointer<Void>);
typedef _LogNative = Void Function(Int32, Pointer<Utf8>, Pointer<Void>);
typedef _BeforeNative = Void Function(Pointer<Utf8>, Pointer<Void>);
typedef _PostCbNative = Void Function(AgwResponse, Pointer<Void>);
const _levels = ['DBG', 'INF', 'WRN', 'ERR'];
void _printLog(String tag, int level, Pointer<Utf8> message) {
final lvl = (level >= 0 && level < _levels.length) ? _levels[level] : '?';
stdout.writeln(' $tag [agw][$lvl] ${message.toDartString()}');
}
class _Cfg {
final Pointer<AgwConfig> ptr;
final List<Pointer<NativeType>> allocs;
_Cfg(this.ptr, this.allocs);
void free() {
for (final p in allocs) {
calloc.free(p);
}
calloc.free(ptr);
}
}
String _libPath() {
final env = Platform.environment['AGW_CAPI_LIB'];
if (env != null) return env;
final base = '${Directory.current.path}/../../build-local';
if (Platform.isMacOS) return '$base/libagw_capi.dylib';
if (Platform.isWindows) return '$base/agw_capi.dll';
return '$base/libagw_capi.so';
}
String _defaultPubKey() {
final f = File('${Directory.current.path}/../../tests/golden/fixtures/test_rsa_pub.pem');
return f.existsSync() ? f.readAsStringSync() : 'not a real pem key';
}
late String gateway, pubKey, payload;
late int isDev;
late List<String> s3List;
_Cfg buildConfig(Pointer<Void> logFn, Pointer<Void> beforeFn) {
final cfg = calloc<AgwConfig>();
final allocs = <Pointer<NativeType>>[];
final gw = gateway.toNativeUtf8();
final pk = pubKey.toNativeUtf8();
allocs.add(gw);
allocs.add(pk);
cfg.ref.gatewayEndpoint = gw;
cfg.ref.agwPublicKeyPem = pk;
cfg.ref.requestTimeoutMsecs = 8000;
cfg.ref.isDevEnvironment = isDev;
cfg.ref.onBeforeRequest = beforeFn;
cfg.ref.log = logFn;
if (s3List.isNotEmpty) {
final arr = calloc<Pointer<Utf8>>(s3List.length);
for (var i = 0; i < s3List.length; i++) {
arr[i] = s3List[i].toNativeUtf8();
allocs.add(arr[i]);
}
cfg.ref.s3Primary = arr;
cfg.ref.s3PrimaryCount = s3List.length;
allocs.add(arr);
}
return _Cfg(cfg, allocs);
}
Future<int> main() async {
final env = Platform.environment;
final lib = DynamicLibrary.open(_libPath());
final create = lib.lookupFunction<_CreateC, _CreateC>('agw_client_create');
final post = lib.lookupFunction<_PostC, _PostC>('agw_client_post');
final postAsync = lib.lookupFunction<_PostAsyncC, _PostAsyncDart>('agw_client_post_async');
final free = lib.lookupFunction<_FreeC, _FreeDart>('agw_response_free');
final destroy = lib.lookupFunction<_DestroyC, _DestroyDart>('agw_client_destroy');
gateway = env['AGW_GATEWAY'] ?? 'http://gw.example.test/';
final pubKeyFile = env['AGW_PUBKEY_FILE'];
pubKey = pubKeyFile != null ? File(pubKeyFile).readAsStringSync() : _defaultPubKey();
final endpoint = env['AGW_ENDPOINT'] ?? '%1v1/services';
payload = env['AGW_PAYLOAD'] ??
'{"os_version":"macos","app_version":"4.9.0","cli_name":"amnezia","app_language":"en"}';
isDev = (env['AGW_DEV'] == '1') ? 1 : 0;
s3List = (env['AGW_S3_PRIMARY'] ?? '')
.split(',')
.map((s) => s.trim())
.where((s) => s.isNotEmpty)
.toList();
stdout.writeln('=== agw-sdk Dart demo ===');
stdout.writeln('gateway=$gateway endpoint=$endpoint dev=$isDev s3primary=${s3List.length}');
stdout.writeln('pubkey=${pubKeyFile ?? "(test fixture)"}');
final endpointC = endpoint.toNativeUtf8();
final payloadC = payload.toNativeUtf8();
final svc = ''.toNativeUtf8();
// ---------- SYNC (коллбэки isolateLocal: ядро sync исполняется на этом потоке) ----------
stdout.writeln('\n--- SYNC post ---');
final logSync = NativeCallable<_LogNative>.isolateLocal(
(int lvl, Pointer<Utf8> m, Pointer<Void> _) => _printLog('[sync]', lvl, m));
final beforeSync = NativeCallable<_BeforeNative>.isolateLocal(
(Pointer<Utf8> h, Pointer<Void> _) =>
stdout.writeln(' [sync] → onBeforeRequest host=${h.toDartString()}'));
final cfgSync = buildConfig(logSync.nativeFunction.cast(), beforeSync.nativeFunction.cast());
final clientSync = create(cfgSync.ptr);
final resp = post(clientSync, endpointC, payloadC, svc, svc, nullptr);
stdout.writeln(' [sync] RESULT errorCode=${resp.error} bodyLen=${resp.bodyLen}');
if (resp.body != nullptr && resp.bodyLen > 0) {
final body = resp.body.toDartString(length: resp.bodyLen);
stdout.writeln(' [sync] body=${body.length > 200 ? "${body.substring(0, 200)}…" : body}');
}
final rp = calloc<AgwResponse>()
..ref.error = resp.error
..ref.body = resp.body
..ref.bodyLen = resp.bodyLen;
free(rp);
calloc.free(rp);
destroy(clientSync);
cfgSync.free();
logSync.close();
beforeSync.close();
// ---------- ASYNC (коллбэки listener: прилетают с потока пула SDK) ----------
if (env['AGW_ASYNC'] == '1') {
stdout.writeln('\n--- ASYNC post (коллбэк с потока пула) ---');
// ВАЖНО: лог-хук/onBeforeRequest сюда НЕ вешаем. Их const char* живут лишь во время вызова на
// потоке пула, а NativeCallable.listener выполняется позже на Dart event-loop → указатель был бы
// висячим. Result-коллбэк безопасен: body выделен в куче и принадлежит вызывающему (мы его освобождаем).
final done = Completer<void>();
final resultCb = NativeCallable<_PostCbNative>.listener((AgwResponse r, Pointer<Void> _) {
stdout.writeln(' [async] CALLBACK (прилетел с потока пула) errorCode=${r.error} bodyLen=${r.bodyLen}');
if (r.body != nullptr && r.bodyLen > 0) {
final body = r.body.toDartString(length: r.bodyLen);
stdout.writeln(' [async] body=${body.length > 200 ? "${body.substring(0, 200)}…" : body}');
}
final p = calloc<AgwResponse>()
..ref.error = r.error
..ref.body = r.body
..ref.bodyLen = r.bodyLen;
free(p);
calloc.free(p);
done.complete();
});
final cfgAsync = buildConfig(nullptr, nullptr); // без лог/before-хуков (см. выше)
final clientAsync = create(cfgAsync.ptr);
stdout.writeln(' [async] post_async отправлен, ждём коллбэк…');
postAsync(clientAsync, endpointC, payloadC, svc, svc, resultCb.nativeFunction.cast(),
nullptr, nullptr);
await done.future; // ждём, пока коллбэк прилетит с потока пула
destroy(clientAsync);
cfgAsync.free();
resultCb.close();
}
calloc.free(endpointC);
calloc.free(payloadC);
calloc.free(svc);
stdout.writeln('\n=== done ===');
return 0;
}

View File

@@ -0,0 +1,9 @@
name: agw_dart_smoke
description: Smoke-вызов C-ABI agw-sdk через dart:ffi.
publish_to: none
environment:
sdk: ">=3.0.0 <4.0.0"
dependencies:
ffi: ^2.1.0

10
agw-sdk/include/agw/agw.h Normal file
View File

@@ -0,0 +1,10 @@
#ifndef AGW_AGW_H
#define AGW_AGW_H
#include "agw/cancellation.h"
#include "agw/gateway_controller.h"
#include "agw/config.h"
#include "agw/http.h"
#include "agw/types.h"
#endif

View File

@@ -0,0 +1,80 @@
#ifndef AGW_C_ABI_H
#define AGW_C_ABI_H
#include <stddef.h>
#ifdef __cplusplus
extern "C" {
#endif
#if defined(_WIN32)
#if defined(AGW_BUILDING_SHARED)
#define AGW_API __declspec(dllexport)
#elif defined(AGW_USING_SHARED)
#define AGW_API __declspec(dllimport)
#else
#define AGW_API
#endif
#else
#define AGW_API __attribute__((visibility("default")))
#endif
typedef struct agw_client agw_client;
typedef struct agw_cancel_token agw_cancel_token;
typedef void (*agw_before_request_fn)(const char *host, void *user_data);
typedef void (*agw_log_fn)(int level, const char *message, void *user_data);
typedef struct
{
const char *gateway_endpoint;
const char *agw_public_key_pem;
const char *const *s3_primary_endpoints;
size_t s3_primary_count;
const char *const *s3_fallback_endpoints;
size_t s3_fallback_count;
int is_dev_environment;
int request_timeout_msecs;
int proxy_health_timeout_msecs;
int proxy_storage_timeout_msecs;
int thread_pool_size;
agw_before_request_fn on_before_request;
void *on_before_request_user_data;
agw_log_fn log;
void *log_user_data;
} agw_config;
typedef struct
{
int error;
char *body;
size_t body_len;
} agw_response;
typedef void (*agw_post_callback)(agw_response response, void *user_data);
AGW_API agw_client *agw_client_create(const agw_config *config);
AGW_API void agw_client_destroy(agw_client *client);
AGW_API agw_response agw_client_post(agw_client *client, const char *endpoint, const char *payload,
const char *service_type, const char *user_country_code,
agw_cancel_token *cancel_token);
AGW_API void agw_client_post_async(agw_client *client, const char *endpoint, const char *payload,
const char *service_type, const char *user_country_code, agw_post_callback callback,
void *user_data, agw_cancel_token *cancel_token);
AGW_API void agw_response_free(agw_response *response);
AGW_API agw_cancel_token *agw_cancel_token_create(void);
AGW_API void agw_cancel_token_cancel(agw_cancel_token *token);
AGW_API void agw_cancel_token_destroy(agw_cancel_token *token);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,30 @@
#ifndef AGW_CANCELLATION_H
#define AGW_CANCELLATION_H
#include <atomic>
namespace agw
{
class CancellationToken
{
public:
CancellationToken() = default;
CancellationToken(const CancellationToken &) = delete;
CancellationToken &operator=(const CancellationToken &) = delete;
void cancel() noexcept
{
m_cancelled.store(true, std::memory_order_relaxed);
}
bool isCancelled() const noexcept
{
return m_cancelled.load(std::memory_order_relaxed);
}
private:
std::atomic<bool> m_cancelled { false };
};
}
#endif

View File

@@ -0,0 +1,38 @@
#ifndef AGW_CONFIG_H
#define AGW_CONFIG_H
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "agw/http.h"
#include "agw/types.h"
namespace agw
{
struct Config
{
std::string gatewayEndpoint;
std::string agwPublicKeyPem;
std::vector<std::string> s3PrimaryEndpoints;
std::vector<std::string> s3FallbackEndpoints;
bool isDevEnvironment = false;
int requestTimeoutMsecs = 12000;
int proxyHealthTimeoutMsecs = 1000;
int proxyStorageTimeoutMsecs = 3000;
int threadPoolSize = 4;
std::function<void(const std::string &host)> onBeforeRequest;
std::function<void(LogLevel, const std::string &message)> log;
std::shared_ptr<IHttpClient> httpClient;
};
}
#endif

View File

@@ -0,0 +1,40 @@
#ifndef AGW_GATEWAY_CONTROLLER_H
#define AGW_GATEWAY_CONTROLLER_H
#include <functional>
#include <future>
#include <memory>
#include <string>
#include "agw/cancellation.h"
#include "agw/config.h"
#include "agw/types.h"
namespace agw {
class GatewayController {
public:
explicit GatewayController(Config config);
~GatewayController();
GatewayController(GatewayController &&) noexcept;
GatewayController &operator=(GatewayController &&) noexcept;
GatewayController(const GatewayController &) = delete;
GatewayController &operator=(const GatewayController &) = delete;
Response post(const std::string &endpoint, const std::string &payload, const FailoverContext &ctx,
CancellationToken *cancel = nullptr);
void postAsync(const std::string &endpoint, const std::string &payload,
std::function<void(Response)> onResult, const FailoverContext &ctx,
CancellationToken *cancel = nullptr);
std::future<Response> postFuture(const std::string &endpoint, const std::string &payload,
const FailoverContext &ctx, CancellationToken *cancel = nullptr);
private:
struct Impl;
std::unique_ptr<Impl> m_impl;
};
}
#endif

View File

@@ -0,0 +1,51 @@
#ifndef AGW_HTTP_H
#define AGW_HTTP_H
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace agw
{
enum class TransportError {
None = 0,
Timeout,
Canceled,
OperationNotImplemented,
ConnectionError,
};
struct HttpRequest
{
std::string url;
std::string method;
std::string body;
std::vector<std::pair<std::string, std::string>> headers;
int timeoutMsecs = 0;
std::function<bool()> cancelCheck;
};
struct HttpResponse
{
TransportError error = TransportError::None;
std::string errorString;
int httpStatusCode = 0;
bool sslError = false;
std::string body;
};
class IHttpClient
{
public:
virtual ~IHttpClient() = default;
virtual HttpResponse send(const HttpRequest &request) = 0;
};
std::unique_ptr<IHttpClient> makeDefaultHttpClient();
}
#endif

View File

@@ -0,0 +1,55 @@
#ifndef AGW_TYPES_H
#define AGW_TYPES_H
#include <string>
namespace agw
{
enum class ErrorCode : int {
NoError = 0,
Cancelled = 1,
ApiConfigDownloadError = 1100,
ApiConfigAlreadyAdded = 1101,
ApiConfigEmptyError = 1102,
ApiConfigTimeoutError = 1103,
ApiConfigSslError = 1104,
ApiMissingAgwPublicKey = 1105,
ApiConfigDecryptionError = 1106,
ApiServicesMissingError = 1107,
ApiConfigLimitError = 1108,
ApiNotFoundError = 1109,
ApiMigrationError = 1110,
ApiUpdateRequestError = 1111,
ApiSubscriptionExpiredError = 1112,
ApiPurchaseError = 1113,
ApiSubscriptionNotActiveError = 1114,
ApiNoPurchasedSubscriptionsError = 1115,
ApiTrialAlreadyUsedError = 1116,
ApiCaptchaRequiredError = 1117,
ApiCaptchaInvalidError = 1118,
ApiCaptchaRefreshError = 1119,
ApiRateLimitError = 1120,
};
enum class LogLevel : int {
Debug,
Info,
Warning,
Error
};
struct Response
{
ErrorCode error = ErrorCode::NoError;
std::string body;
};
struct FailoverContext
{
std::string serviceType;
std::string userCountryCode;
};
}
#endif

201
agw-sdk/src/c_abi.cpp Normal file
View File

@@ -0,0 +1,201 @@
#include "agw/c_abi.h"
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "agw/cancellation.h"
#include "agw/gateway_controller.h"
#include "agw/config.h"
#include "detail/test_hooks.h"
struct agw_client
{
explicit agw_client(agw::Config cfg) : client(std::move(cfg))
{
}
agw::GatewayController client;
};
struct agw_cancel_token
{
agw::CancellationToken token;
};
namespace agw::detail
{
namespace
{
std::mutex g_testHttpMutex;
std::shared_ptr<IHttpClient> g_testHttp;
}
void setNextTestHttpClient(std::shared_ptr<IHttpClient> http)
{
std::lock_guard<std::mutex> lock(g_testHttpMutex);
g_testHttp = std::move(http);
}
std::shared_ptr<IHttpClient> takeNextTestHttpClient()
{
std::lock_guard<std::mutex> lock(g_testHttpMutex);
std::shared_ptr<IHttpClient> h = std::move(g_testHttp);
g_testHttp.reset();
return h;
}
}
namespace
{
std::string cstr(const char *s)
{
return s ? std::string(s) : std::string();
}
agw_response toCResponse(const agw::Response &r)
{
agw_response out;
out.error = static_cast<int>(r.error);
out.body = nullptr;
out.body_len = r.body.size();
char *buf = static_cast<char *>(std::malloc(r.body.size() + 1));
if (buf != nullptr) {
if (!r.body.empty()) {
std::memcpy(buf, r.body.data(), r.body.size());
}
buf[r.body.size()] = '\0';
out.body = buf;
} else {
out.body_len = 0;
}
return out;
}
}
extern "C" {
agw_client *agw_client_create(const agw_config *config)
{
if (config == nullptr) {
return nullptr;
}
agw::Config cfg;
cfg.gatewayEndpoint = cstr(config->gateway_endpoint);
cfg.agwPublicKeyPem = cstr(config->agw_public_key_pem);
for (size_t i = 0; i < config->s3_primary_count; ++i) {
cfg.s3PrimaryEndpoints.push_back(cstr(config->s3_primary_endpoints[i]));
}
for (size_t i = 0; i < config->s3_fallback_count; ++i) {
cfg.s3FallbackEndpoints.push_back(cstr(config->s3_fallback_endpoints[i]));
}
cfg.isDevEnvironment = config->is_dev_environment != 0;
if (config->request_timeout_msecs > 0)
cfg.requestTimeoutMsecs = config->request_timeout_msecs;
if (config->proxy_health_timeout_msecs > 0)
cfg.proxyHealthTimeoutMsecs = config->proxy_health_timeout_msecs;
if (config->proxy_storage_timeout_msecs > 0)
cfg.proxyStorageTimeoutMsecs = config->proxy_storage_timeout_msecs;
if (config->thread_pool_size > 0)
cfg.threadPoolSize = config->thread_pool_size;
if (config->on_before_request != nullptr) {
agw_before_request_fn fn = config->on_before_request;
void *ud = config->on_before_request_user_data;
cfg.onBeforeRequest = [fn, ud](const std::string &host) { fn(host.c_str(), ud); };
}
if (config->log != nullptr) {
agw_log_fn fn = config->log;
void *ud = config->log_user_data;
cfg.log = [fn, ud](agw::LogLevel level, const std::string &msg) { fn(static_cast<int>(level), msg.c_str(), ud); };
}
if (auto http = agw::detail::takeNextTestHttpClient()) {
cfg.httpClient = http;
}
try {
return new agw_client(std::move(cfg));
} catch (...) {
return nullptr;
}
}
void agw_client_destroy(agw_client *client)
{
delete client;
}
agw_response agw_client_post(agw_client *client, const char *endpoint, const char *payload, const char *service_type,
const char *user_country_code, agw_cancel_token *cancel_token)
{
if (client == nullptr) {
agw_response out;
out.error = static_cast<int>(agw::ErrorCode::ApiConfigDownloadError);
out.body = nullptr;
out.body_len = 0;
return out;
}
agw::FailoverContext ctx { cstr(service_type), cstr(user_country_code) };
agw::CancellationToken *tk = cancel_token ? &cancel_token->token : nullptr;
agw::Response r = client->client.post(cstr(endpoint), cstr(payload), ctx, tk);
return toCResponse(r);
}
void agw_client_post_async(agw_client *client, const char *endpoint, const char *payload, const char *service_type,
const char *user_country_code, agw_post_callback callback, void *user_data,
agw_cancel_token *cancel_token)
{
if (client == nullptr || callback == nullptr) {
return;
}
agw::FailoverContext ctx { cstr(service_type), cstr(user_country_code) };
agw::CancellationToken *tk = cancel_token ? &cancel_token->token : nullptr;
client->client.postAsync(
cstr(endpoint), cstr(payload),
[callback, user_data](agw::Response r) {
agw_response cr = toCResponse(r);
callback(cr, user_data);
},
ctx, tk);
}
void agw_response_free(agw_response *response)
{
if (response == nullptr) {
return;
}
std::free(response->body);
response->body = nullptr;
response->body_len = 0;
}
agw_cancel_token *agw_cancel_token_create(void)
{
try {
return new agw_cancel_token();
} catch (...) {
return nullptr;
}
}
void agw_cancel_token_cancel(agw_cancel_token *token)
{
if (token != nullptr) {
token->token.cancel();
}
}
void agw_cancel_token_destroy(agw_cancel_token *token)
{
delete token;
}
}

View File

@@ -0,0 +1,87 @@
#include "aes.h"
#include <memory>
#include <stdexcept>
#include <openssl/evp.h>
namespace agw::crypto
{
namespace
{
using CtxPtr = std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)>;
constexpr int kAes256KeyLen = 32;
constexpr int kAesBlock = 16;
void checkKeyIv(const std::vector<std::uint8_t> &key, const std::vector<std::uint8_t> &iv)
{
if (key.size() != static_cast<std::size_t>(kAes256KeyLen)) {
throw std::runtime_error("agw::crypto::aes: key must be 32 bytes (AES-256)");
}
if (iv.size() < static_cast<std::size_t>(kAesBlock)) {
throw std::runtime_error("agw::crypto::aes: iv must be at least 16 bytes");
}
}
}
std::vector<std::uint8_t> aesEncryptCbc(const std::vector<std::uint8_t> &data, const std::vector<std::uint8_t> &key,
const std::vector<std::uint8_t> &iv)
{
checkKeyIv(key, iv);
CtxPtr ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
if (!ctx) {
throw std::runtime_error("agw::crypto::aes: EVP_CIPHER_CTX_new failed");
}
if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_256_cbc(), nullptr, key.data(), iv.data()) != 1) {
throw std::runtime_error("agw::crypto::aes: EVP_EncryptInit_ex failed");
}
std::vector<std::uint8_t> out(data.size() + kAesBlock);
int len = 0;
if (EVP_EncryptUpdate(ctx.get(), out.data(), &len, data.data(), static_cast<int>(data.size())) != 1) {
throw std::runtime_error("agw::crypto::aes: EVP_EncryptUpdate failed");
}
int total = len;
if (EVP_EncryptFinal_ex(ctx.get(), out.data() + total, &len) != 1) {
throw std::runtime_error("agw::crypto::aes: EVP_EncryptFinal_ex failed");
}
total += len;
out.resize(static_cast<std::size_t>(total));
return out;
}
std::vector<std::uint8_t> aesDecryptCbc(const std::vector<std::uint8_t> &data, const std::vector<std::uint8_t> &key,
const std::vector<std::uint8_t> &iv)
{
checkKeyIv(key, iv);
CtxPtr ctx(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
if (!ctx) {
throw std::runtime_error("agw::crypto::aes: EVP_CIPHER_CTX_new failed");
}
if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_256_cbc(), nullptr, key.data(), iv.data()) != 1) {
throw std::runtime_error("agw::crypto::aes: EVP_DecryptInit_ex failed");
}
std::vector<std::uint8_t> out(data.size() + kAesBlock);
int len = 0;
if (EVP_DecryptUpdate(ctx.get(), out.data(), &len, data.data(), static_cast<int>(data.size())) != 1) {
throw std::runtime_error("agw::crypto::aes: EVP_DecryptUpdate failed");
}
int total = len;
if (EVP_DecryptFinal_ex(ctx.get(), out.data() + total, &len) != 1) {
throw std::runtime_error("agw::crypto::aes: EVP_DecryptFinal_ex failed (bad key/iv/padding)");
}
total += len;
out.resize(static_cast<std::size_t>(total));
return out;
}
}

16
agw-sdk/src/crypto/aes.h Normal file
View File

@@ -0,0 +1,16 @@
#ifndef AGW_CRYPTO_AES_H
#define AGW_CRYPTO_AES_H
#include <cstdint>
#include <vector>
namespace agw::crypto
{
std::vector<std::uint8_t> aesEncryptCbc(const std::vector<std::uint8_t> &data, const std::vector<std::uint8_t> &key,
const std::vector<std::uint8_t> &iv);
std::vector<std::uint8_t> aesDecryptCbc(const std::vector<std::uint8_t> &data, const std::vector<std::uint8_t> &key,
const std::vector<std::uint8_t> &iv);
}
#endif

View File

@@ -0,0 +1,59 @@
#include "hash.h"
#include <stdexcept>
#include <openssl/sha.h>
namespace agw::crypto
{
namespace
{
int hexNibble(char c)
{
if (c >= '0' && c <= '9')
return c - '0';
if (c >= 'a' && c <= 'f')
return c - 'a' + 10;
if (c >= 'A' && c <= 'F')
return c - 'A' + 10;
return -1;
}
}
std::vector<std::uint8_t> sha512(const std::vector<std::uint8_t> &data)
{
std::vector<std::uint8_t> out(SHA512_DIGEST_LENGTH);
SHA512(data.data(), data.size(), out.data());
return out;
}
std::string toHex(const std::vector<std::uint8_t> &data)
{
static const char *digits = "0123456789abcdef";
std::string out;
out.reserve(data.size() * 2);
for (std::uint8_t b : data) {
out.push_back(digits[b >> 4]);
out.push_back(digits[b & 0x0F]);
}
return out;
}
std::vector<std::uint8_t> fromHex(const std::string &hex)
{
if (hex.size() % 2 != 0) {
throw std::runtime_error("agw::crypto::fromHex: odd-length input");
}
std::vector<std::uint8_t> out;
out.reserve(hex.size() / 2);
for (std::size_t i = 0; i < hex.size(); i += 2) {
const int hi = hexNibble(hex[i]);
const int lo = hexNibble(hex[i + 1]);
if (hi < 0 || lo < 0) {
throw std::runtime_error("agw::crypto::fromHex: invalid hex character");
}
out.push_back(static_cast<std::uint8_t>((hi << 4) | lo));
}
return out;
}
}

17
agw-sdk/src/crypto/hash.h Normal file
View File

@@ -0,0 +1,17 @@
#ifndef AGW_CRYPTO_HASH_H
#define AGW_CRYPTO_HASH_H
#include <cstdint>
#include <string>
#include <vector>
namespace agw::crypto
{
std::vector<std::uint8_t> sha512(const std::vector<std::uint8_t> &data);
std::string toHex(const std::vector<std::uint8_t> &data);
std::vector<std::uint8_t> fromHex(const std::string &hex);
}
#endif

View File

@@ -0,0 +1,20 @@
#include "rng.h"
#include <stdexcept>
#include <openssl/rand.h>
namespace agw::crypto
{
std::vector<std::uint8_t> DefaultRng::bytes(std::size_t n)
{
std::vector<std::uint8_t> out(n);
if (n == 0) {
return out;
}
if (RAND_priv_bytes(out.data(), static_cast<int>(n)) != 1) {
throw std::runtime_error("agw::crypto::DefaultRng: RAND_priv_bytes failed");
}
return out;
}
}

24
agw-sdk/src/crypto/rng.h Normal file
View File

@@ -0,0 +1,24 @@
#ifndef AGW_CRYPTO_RNG_H
#define AGW_CRYPTO_RNG_H
#include <cstddef>
#include <cstdint>
#include <vector>
namespace agw::crypto
{
class IRng
{
public:
virtual ~IRng() = default;
virtual std::vector<std::uint8_t> bytes(std::size_t n) = 0;
};
class DefaultRng : public IRng
{
public:
std::vector<std::uint8_t> bytes(std::size_t n) override;
};
}
#endif

111
agw-sdk/src/crypto/rsa.cpp Normal file
View File

@@ -0,0 +1,111 @@
#include "rsa.h"
#include <memory>
#include <stdexcept>
#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>
namespace agw::crypto
{
namespace
{
using BioPtr = std::unique_ptr<BIO, decltype(&BIO_free)>;
using PkeyPtr = std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)>;
using PkeyCtxPtr = std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)>;
PkeyPtr loadPublicKey(const std::string &pem)
{
BioPtr bio(BIO_new_mem_buf(pem.data(), static_cast<int>(pem.size())), BIO_free);
if (!bio) {
throw std::runtime_error("agw::crypto::rsa: BIO_new_mem_buf failed");
}
EVP_PKEY *raw = nullptr;
if (!PEM_read_bio_PUBKEY(bio.get(), &raw, nullptr, nullptr)) {
throw std::runtime_error("agw::crypto::rsa: PEM_read_bio_PUBKEY failed");
}
return PkeyPtr(raw, EVP_PKEY_free);
}
PkeyPtr loadPrivateKey(const std::string &pem)
{
BioPtr bio(BIO_new_mem_buf(pem.data(), static_cast<int>(pem.size())), BIO_free);
if (!bio) {
throw std::runtime_error("agw::crypto::rsa: BIO_new_mem_buf failed");
}
EVP_PKEY *raw = nullptr;
if (!PEM_read_bio_PrivateKey(bio.get(), &raw, nullptr, nullptr)) {
throw std::runtime_error("agw::crypto::rsa: PEM_read_bio_PrivateKey failed");
}
return PkeyPtr(raw, EVP_PKEY_free);
}
}
std::vector<std::uint8_t> rsaEncryptPublicPkcs1(const std::vector<std::uint8_t> &plaintext,
const std::string &publicKeyPem)
{
PkeyPtr key = loadPublicKey(publicKeyPem);
PkeyCtxPtr ctx(EVP_PKEY_CTX_new(key.get(), nullptr), EVP_PKEY_CTX_free);
if (!ctx) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_CTX_new failed");
}
if (EVP_PKEY_encrypt_init(ctx.get()) != 1) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_encrypt_init failed");
}
if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING) != 1) {
throw std::runtime_error("agw::crypto::rsa: set_rsa_padding failed");
}
std::size_t outLen = 0;
if (EVP_PKEY_encrypt(ctx.get(), nullptr, &outLen, plaintext.data(), plaintext.size()) != 1) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_encrypt (size) failed");
}
std::vector<std::uint8_t> out(outLen);
if (EVP_PKEY_encrypt(ctx.get(), out.data(), &outLen, plaintext.data(), plaintext.size()) != 1) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_encrypt failed");
}
out.resize(outLen);
return out;
}
bool rsaPublicKeyValid(const std::string &publicKeyPem)
{
try {
loadPublicKey(publicKeyPem);
return true;
} catch (...) {
return false;
}
}
std::vector<std::uint8_t> rsaDecryptPrivatePkcs1(const std::vector<std::uint8_t> &ciphertext,
const std::string &privateKeyPem)
{
PkeyPtr key = loadPrivateKey(privateKeyPem);
PkeyCtxPtr ctx(EVP_PKEY_CTX_new(key.get(), nullptr), EVP_PKEY_CTX_free);
if (!ctx) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_CTX_new failed");
}
if (EVP_PKEY_decrypt_init(ctx.get()) != 1) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_decrypt_init failed");
}
if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING) != 1) {
throw std::runtime_error("agw::crypto::rsa: set_rsa_padding failed");
}
std::size_t outLen = 0;
if (EVP_PKEY_decrypt(ctx.get(), nullptr, &outLen, ciphertext.data(), ciphertext.size()) != 1) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_decrypt (size) failed");
}
std::vector<std::uint8_t> out(outLen);
if (EVP_PKEY_decrypt(ctx.get(), out.data(), &outLen, ciphertext.data(), ciphertext.size()) != 1) {
throw std::runtime_error("agw::crypto::rsa: EVP_PKEY_decrypt failed");
}
out.resize(outLen);
return out;
}
}

19
agw-sdk/src/crypto/rsa.h Normal file
View File

@@ -0,0 +1,19 @@
#ifndef AGW_CRYPTO_RSA_H
#define AGW_CRYPTO_RSA_H
#include <cstdint>
#include <string>
#include <vector>
namespace agw::crypto
{
std::vector<std::uint8_t> rsaEncryptPublicPkcs1(const std::vector<std::uint8_t> &plaintext,
const std::string &publicKeyPem);
std::vector<std::uint8_t> rsaDecryptPrivatePkcs1(const std::vector<std::uint8_t> &ciphertext,
const std::string &privateKeyPem);
bool rsaPublicKeyValid(const std::string &publicKeyPem);
}
#endif

View File

@@ -0,0 +1,14 @@
#ifndef AGW_DETAIL_TEST_HOOKS_H
#define AGW_DETAIL_TEST_HOOKS_H
#include <memory>
#include "agw/http.h"
namespace agw::detail
{
void setNextTestHttpClient(std::shared_ptr<IHttpClient> http);
std::shared_ptr<IHttpClient> takeNextTestHttpClient();
}
#endif

View File

@@ -0,0 +1,100 @@
#include "failover/bypass_policy.h"
#include "protocol/keys.h"
#include "util/json.h"
namespace agw::failover
{
namespace
{
constexpr const char *kPattern1 = "No active configuration found for";
constexpr const char *kPattern2 = "No non-revoked public key found for";
constexpr const char *kPattern3 = "Account not found.";
constexpr const char *kPatternQrSessionNotFound = "QR session not found";
constexpr const char *kPatternSessionNotFound = "Session not found";
constexpr const char *kUpdateRequestPattern = "client version update is required";
constexpr const char *kUnprocessableSubscriptionMessage =
"Failed to retrieve subscription information. Is it activated?";
constexpr int kNotFound = 404;
constexpr int kNotImplemented = 501;
constexpr int kPaymentRequired = 402;
constexpr int kConflict = 409;
constexpr int kRequestTimeout = 408;
constexpr int kUnprocessableEntity = 422;
bool contains(const std::string &body, const char *needle)
{
return body.find(needle) != std::string::npos;
}
std::string trim(const std::string &s)
{
std::size_t b = 0, e = s.size();
while (b < e && (s[b] == ' ' || s[b] == '\t' || s[b] == '\n' || s[b] == '\r'))
++b;
while (e > b && (s[e - 1] == ' ' || s[e - 1] == '\t' || s[e - 1] == '\n' || s[e - 1] == '\r'))
--e;
return s.substr(b, e - b);
}
}
bool shouldBypassProxy(TransportError transportError, const std::string &decryptedBody, bool decryptionSuccessful)
{
if (!decryptionSuccessful) {
return true;
}
int apiHttpStatus = -1;
std::string apiErrorMessage;
try {
util::Json obj = util::Json::parse(decryptedBody);
if (obj.is_object()) {
if (auto it = obj.find(protocol::keys::httpStatus); it != obj.end() && it->is_number_integer()) {
apiHttpStatus = it->get<int>();
}
if (auto it = obj.find(protocol::keys::message); it != obj.end() && it->is_string()) {
apiErrorMessage = trim(it->get<std::string>());
}
}
} catch (...) {
}
if (transportError == TransportError::Canceled || transportError == TransportError::Timeout) {
return true;
}
if (contains(decryptedBody, "html")) {
return true;
}
if (apiHttpStatus == kRequestTimeout) {
return false;
}
if (apiHttpStatus == kNotFound) {
if (contains(decryptedBody, kPattern1) || contains(decryptedBody, kPattern2)
|| contains(decryptedBody, kPattern3) || contains(decryptedBody, kPatternQrSessionNotFound)
|| contains(decryptedBody, kPatternSessionNotFound)) {
return false;
}
return true;
}
if (apiHttpStatus == kNotImplemented) {
if (contains(decryptedBody, kUpdateRequestPattern)) {
return false;
}
return true;
}
if (apiHttpStatus == kConflict) {
return false;
}
if (apiHttpStatus == kPaymentRequired) {
return false;
}
if (apiHttpStatus == kUnprocessableEntity) {
return apiErrorMessage != kUnprocessableSubscriptionMessage;
}
if (transportError != TransportError::None) {
return true;
}
return false;
}
}

View File

@@ -0,0 +1,13 @@
#ifndef AGW_FAILOVER_BYPASS_POLICY_H
#define AGW_FAILOVER_BYPASS_POLICY_H
#include <string>
#include "agw/http.h"
namespace agw::failover
{
bool shouldBypassProxy(TransportError transportError, const std::string &decryptedBody, bool decryptionSuccessful);
}
#endif

View File

@@ -0,0 +1,71 @@
#include "failover/proxy_list.h"
#include "crypto/aes.h"
#include "crypto/hash.h"
#include "util/base64.h"
#include "util/json.h"
namespace agw::failover
{
namespace
{
void appendStorageUrls(const std::vector<std::string> &baseUrls, const FailoverContext &ctx,
std::vector<std::string> &target)
{
if (!ctx.serviceType.empty()) {
const std::string token = "endpoints-" + ctx.serviceType + "-" + ctx.userCountryCode;
const std::string encoded =
util::base64UrlEncodeNoPad(std::vector<std::uint8_t>(token.begin(), token.end()));
for (const auto &base : baseUrls) {
target.push_back(base + encoded + ".json");
}
}
for (const auto &base : baseUrls) {
target.push_back(base + "endpoints.json");
}
}
std::vector<std::string> parseEndpointsArray(const std::string &json)
{
std::vector<std::string> out;
try {
util::Json doc = util::Json::parse(json);
if (doc.is_array()) {
for (const auto &el : doc) {
if (el.is_string()) {
out.push_back(el.get<std::string>());
}
}
}
} catch (...) {
}
return out;
}
}
std::vector<std::string> buildStorageUrls(const std::vector<std::string> &primaryBaseUrls,
const std::vector<std::string> &fallbackBaseUrls,
const FailoverContext &ctx)
{
std::vector<std::string> result;
appendStorageUrls(primaryBaseUrls, ctx, result);
appendStorageUrls(fallbackBaseUrls, ctx, result);
return result;
}
std::vector<std::string> decodeProxyList(const std::string &body, bool isDevEnvironment, const std::string &pubKeyPem)
{
if (isDevEnvironment) {
return parseEndpointsArray(body);
}
const std::vector<std::uint8_t> pubBytes(pubKeyPem.begin(), pubKeyPem.end());
const std::string h = crypto::toHex(crypto::sha512(pubBytes));
const std::vector<std::uint8_t> key = crypto::fromHex(h.substr(0, 64));
const std::vector<std::uint8_t> iv = crypto::fromHex(h.substr(64, 32));
const std::vector<std::uint8_t> cipher = util::base64Decode(body);
const std::vector<std::uint8_t> plain = crypto::aesDecryptCbc(cipher, key, iv);
return parseEndpointsArray(std::string(plain.begin(), plain.end()));
}
}

View File

@@ -0,0 +1,19 @@
#ifndef AGW_FAILOVER_PROXY_LIST_H
#define AGW_FAILOVER_PROXY_LIST_H
#include <string>
#include <vector>
#include "agw/types.h"
namespace agw::failover
{
std::vector<std::string> buildStorageUrls(const std::vector<std::string> &primaryBaseUrls,
const std::vector<std::string> &fallbackBaseUrls,
const FailoverContext &ctx);
std::vector<std::string> decodeProxyList(const std::string &body, bool isDevEnvironment,
const std::string &pubKeyPem);
}
#endif

View File

@@ -0,0 +1,21 @@
#include "failover/proxy_picker.h"
namespace agw::failover
{
std::string pickHealthyProxy(IHttpClient &http, const std::vector<std::string> &proxyUrls, int timeoutMsecs)
{
for (const auto &proxy : proxyUrls) {
HttpRequest req;
req.url = proxy + "lmbd-health";
req.method = "GET";
req.headers = { { "Content-Type", "application/json" } };
req.timeoutMsecs = timeoutMsecs;
const HttpResponse resp = http.send(req);
if (resp.error == TransportError::None && !resp.sslError) {
return proxy;
}
}
return { };
}
}

View File

@@ -0,0 +1,14 @@
#ifndef AGW_FAILOVER_PROXY_PICKER_H
#define AGW_FAILOVER_PROXY_PICKER_H
#include <string>
#include <vector>
#include "agw/http.h"
namespace agw::failover
{
std::string pickHealthyProxy(IHttpClient &http, const std::vector<std::string> &proxyUrls, int timeoutMsecs);
}
#endif

View File

@@ -0,0 +1,364 @@
#include "agw/gateway_controller.h"
#include <algorithm>
#include <chrono>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "crypto/rng.h"
#include "failover/bypass_policy.h"
#include "failover/proxy_list.h"
#include "failover/proxy_picker.h"
#include "protocol/error_mapping.h"
#include "protocol/request_builder.h"
#include "protocol/response.h"
#include "util/thread_pool.h"
#include "util/url.h"
#include "util/uuid.h"
namespace agw {
namespace {
bool isCancelled(const CancellationToken *cancel)
{
return cancel != nullptr && cancel->isCancelled();
}
std::function<bool()> makeCancelCheck(CancellationToken *cancel)
{
if (cancel == nullptr) {
return {};
}
return [cancel] { return cancel->isCancelled(); };
}
std::string threadIdStr()
{
std::ostringstream oss;
oss << std::this_thread::get_id();
return oss.str();
}
const char *transportErrorName(TransportError e)
{
switch (e) {
case TransportError::None: return "None";
case TransportError::Timeout: return "Timeout";
case TransportError::Canceled: return "Canceled";
case TransportError::OperationNotImplemented: return "OperationNotImplemented";
case TransportError::ConnectionError: return "ConnectionError";
}
return "?";
}
}
struct GatewayController::Impl {
Config config;
std::shared_ptr<IHttpClient> http;
std::unique_ptr<crypto::IRng> rng;
std::mutex proxyMutex;
std::string cachedProxy;
util::ThreadPool pool;
explicit Impl(Config cfg)
: config(std::move(cfg)),
rng(std::make_unique<crypto::DefaultRng>()),
pool(static_cast<std::size_t>(config.threadPoolSize))
{
http = config.httpClient ? config.httpClient
: std::shared_ptr<IHttpClient>(makeDefaultHttpClient());
log(LogLevel::Info,
"client created: dev=" + std::string(config.isDevEnvironment ? "1" : "0")
+ " timeout=" + std::to_string(config.requestTimeoutMsecs) + "ms"
+ " pool=" + std::to_string(config.threadPoolSize)
+ " s3primary=" + std::to_string(config.s3PrimaryEndpoints.size())
+ " s3fallback=" + std::to_string(config.s3FallbackEndpoints.size())
+ " customHttp=" + std::string(config.httpClient ? "1" : "0"));
}
void log(LogLevel level, const std::string &message) const
{
if (config.log) {
config.log(level, message);
}
}
void dbg(const std::string &message) const { log(LogLevel::Debug, message); }
std::string getCachedProxy()
{
std::lock_guard<std::mutex> lock(proxyMutex);
return cachedProxy;
}
void setCachedProxy(const std::string &proxy)
{
std::lock_guard<std::mutex> lock(proxyMutex);
cachedProxy = proxy;
}
bool attempt(const std::string &endpoint, const std::string &host, const HttpRequest &baseReq,
const std::vector<std::uint8_t> &key, const std::vector<std::uint8_t> &iv,
HttpResponse &resp, protocol::DecryptResult &dec)
{
HttpRequest req = baseReq;
req.url = util::formatEndpoint(endpoint, host);
dbg(" proxy attempt: POST " + req.url);
const auto t0 = std::chrono::steady_clock::now();
resp = http->send(req);
const auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - t0).count();
dec = protocol::tryDecryptResponse(resp.body, key, iv);
const bool bypass = resp.sslError || failover::shouldBypassProxy(resp.error, dec.decryptedBody, dec.ok);
dbg(" proxy attempt result: transport=" + std::string(transportErrorName(resp.error))
+ " ssl=" + std::string(resp.sslError ? "1" : "0") + " http=" + std::to_string(resp.httpStatusCode)
+ " bodyLen=" + std::to_string(resp.body.size()) + " decryptOk=" + std::string(dec.ok ? "1" : "0")
+ " bypassAgain=" + std::string(bypass ? "1" : "0") + " (" + std::to_string(ms) + "ms)");
return !bypass;
}
void runFailover(const std::string &endpoint, const HttpRequest &baseReq, const FailoverContext &ctx,
const std::vector<std::uint8_t> &key, const std::vector<std::uint8_t> &iv,
HttpResponse &resp, protocol::DecryptResult &dec, CancellationToken *cancel)
{
if (isCancelled(cancel)) {
dbg("failover: cancelled before start");
return;
}
std::random_device rd;
std::mt19937 gen(rd());
std::vector<std::string> primary = config.s3PrimaryEndpoints;
std::vector<std::string> fallback = config.s3FallbackEndpoints;
std::shuffle(primary.begin(), primary.end(), gen);
std::shuffle(fallback.begin(), fallback.end(), gen);
const std::vector<std::string> storageUrls = failover::buildStorageUrls(primary, fallback, ctx);
dbg("failover: storage urls=" + std::to_string(storageUrls.size())
+ " service='" + ctx.serviceType + "' country='" + ctx.userCountryCode + "'");
std::vector<std::string> proxyUrls;
for (const auto &storageUrl : storageUrls) {
if (isCancelled(cancel)) {
dbg("failover: cancelled during storage fetch");
return;
}
HttpRequest g;
g.url = storageUrl;
g.method = "GET";
g.headers = {{"Content-Type", "application/json"}};
g.timeoutMsecs = config.proxyStorageTimeoutMsecs;
g.cancelCheck = makeCancelCheck(cancel);
const HttpResponse gr = http->send(g);
dbg(" storage GET " + storageUrl + " → transport=" + std::string(transportErrorName(gr.error))
+ " ssl=" + std::string(gr.sslError ? "1" : "0") + " http=" + std::to_string(gr.httpStatusCode)
+ " bodyLen=" + std::to_string(gr.body.size()));
if (gr.error != TransportError::None || gr.sslError) {
continue;
}
try {
proxyUrls = failover::decodeProxyList(gr.body, config.isDevEnvironment, config.agwPublicKeyPem);
dbg(" decoded proxy list: " + std::to_string(proxyUrls.size()) + " proxies");
break;
} catch (...) {
dbg(" proxy list decode failed → next storage");
continue;
}
}
std::shuffle(proxyUrls.begin(), proxyUrls.end(), gen);
std::string proxy = getCachedProxy();
if (proxy.empty()) {
if (isCancelled(cancel)) {
dbg("failover: cancelled before health-check");
return;
}
dbg("failover: no cached proxy → health-check of " + std::to_string(proxyUrls.size()) + " proxies");
proxy = failover::pickHealthyProxy(*http, proxyUrls, config.proxyHealthTimeoutMsecs);
if (!proxy.empty()) {
dbg("failover: healthy proxy = " + proxy + " (cached)");
setCachedProxy(proxy);
} else {
dbg("failover: no healthy proxy found");
}
} else {
dbg("failover: using cached proxy = " + proxy);
}
if (!proxy.empty()) {
if (isCancelled(cancel)) {
return;
}
if (attempt(endpoint, proxy, baseReq, key, iv, resp, dec)) {
dbg("failover: succeeded via cached/first proxy");
return;
}
}
for (const auto &p : proxyUrls) {
if (isCancelled(cancel)) {
return;
}
if (attempt(endpoint, p, baseReq, key, iv, resp, dec)) {
dbg("failover: succeeded via proxy " + p + " (cached)");
setCachedProxy(p);
return;
}
}
dbg("failover: exhausted all proxies (using last attempt result)");
}
Response executePost(const std::string &endpoint, const std::string &payload,
const FailoverContext &ctx, CancellationToken *cancel)
{
const auto tStart = std::chrono::steady_clock::now();
log(LogLevel::Info, "post START endpoint='" + endpoint + "' service='" + ctx.serviceType
+ "' country='" + ctx.userCountryCode + "' payloadLen=" + std::to_string(payload.size())
+ " thread=" + threadIdStr());
if (isCancelled(cancel)) {
log(LogLevel::Info, "post: cancelled before start");
return Response{ErrorCode::Cancelled, std::string()};
}
protocol::EncryptedRequest enc =
protocol::buildEncryptedRequest(payload, config.agwPublicKeyPem, *rng);
if (enc.error != ErrorCode::NoError) {
log(LogLevel::Warning, "post: request build failed error="
+ std::to_string(static_cast<int>(enc.error)));
return Response{enc.error, std::string()};
}
dbg("request built: bodyLen=" + std::to_string(enc.body.size()) + " (key/iv/salt generated)");
if (isCancelled(cancel)) {
return Response{ErrorCode::Cancelled, std::string()};
}
const std::string requestId = util::makeUuidV4(*rng);
const std::string cached = getCachedProxy();
const std::string directHost = cached.empty() ? config.gatewayEndpoint : cached;
const std::string url = util::formatEndpoint(endpoint, directHost);
dbg("direct request: url=" + url + " reqId=" + requestId
+ " viaCachedProxy=" + std::string(cached.empty() ? "0" : "1"));
if (config.onBeforeRequest) {
const std::string host = util::extractHost(url);
dbg("onBeforeRequest(host=" + host + ")");
config.onBeforeRequest(host);
}
HttpRequest req;
req.url = url;
req.method = "POST";
req.body = enc.body;
req.headers = {
{"Content-Type", "application/json"},
{"X-Client-Request-ID", requestId},
};
req.timeoutMsecs = config.requestTimeoutMsecs;
req.cancelCheck = makeCancelCheck(cancel);
const auto t0 = std::chrono::steady_clock::now();
HttpResponse resp = http->send(req);
const auto httpMs = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - t0).count();
if (isCancelled(cancel)) {
log(LogLevel::Info, "post: cancelled after direct send");
return Response{ErrorCode::Cancelled, std::string()};
}
protocol::DecryptResult dec = protocol::tryDecryptResponse(resp.body, enc.key, enc.iv);
dbg("direct response: transport=" + std::string(transportErrorName(resp.error))
+ " ssl=" + std::string(resp.sslError ? "1" : "0") + " http=" + std::to_string(resp.httpStatusCode)
+ " bodyLen=" + std::to_string(resp.body.size()) + " decryptOk=" + std::string(dec.ok ? "1" : "0")
+ " (" + std::to_string(httpMs) + "ms)");
const bool bypass = !resp.sslError
&& failover::shouldBypassProxy(resp.error, dec.decryptedBody, dec.ok);
if (bypass) {
log(LogLevel::Info, "direct response suspicious — running failover");
runFailover(endpoint, req, ctx, enc.key, enc.iv, resp, dec, cancel);
if (isCancelled(cancel)) {
log(LogLevel::Info, "post: cancelled during failover");
return Response{ErrorCode::Cancelled, std::string()};
}
} else {
dbg("direct response accepted (no failover)");
}
Response out;
out.body = dec.decryptedBody;
const ErrorCode mapped = protocol::mapResponseError(resp.sslError, resp.error, dec.decryptedBody);
const auto totalMs = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - tStart).count();
if (mapped != ErrorCode::NoError) {
out.error = mapped;
log(LogLevel::Warning, "post DONE error=" + std::to_string(static_cast<int>(mapped))
+ " bodyLen=" + std::to_string(out.body.size()) + " (" + std::to_string(totalMs) + "ms)");
return out;
}
if (!dec.ok) {
out.error = ErrorCode::ApiConfigDecryptionError;
log(LogLevel::Error, "post DONE: response decryption failed (1106) ("
+ std::to_string(totalMs) + "ms)");
return out;
}
out.error = ErrorCode::NoError;
log(LogLevel::Info, "post DONE ok bodyLen=" + std::to_string(out.body.size())
+ " (" + std::to_string(totalMs) + "ms)");
return out;
}
};
GatewayController::GatewayController(Config config) : m_impl(std::make_unique<Impl>(std::move(config))) {}
GatewayController::~GatewayController() = default;
GatewayController::GatewayController(GatewayController &&) noexcept = default;
GatewayController &GatewayController::operator=(GatewayController &&) noexcept = default;
Response GatewayController::post(const std::string &endpoint, const std::string &payload,
const FailoverContext &ctx, CancellationToken *cancel)
{
return m_impl->executePost(endpoint, payload, ctx, cancel);
}
void GatewayController::postAsync(const std::string &endpoint, const std::string &payload,
std::function<void(Response)> onResult, const FailoverContext &ctx,
CancellationToken *cancel)
{
Impl *impl = m_impl.get();
impl->dbg("postAsync: submitting to pool (caller thread=" + threadIdStr() + ")");
impl->pool.submit([impl, endpoint, payload, onResult = std::move(onResult), ctx, cancel]() {
impl->dbg("postAsync: running on pool thread=" + threadIdStr());
Response r = impl->executePost(endpoint, payload, ctx, cancel);
if (onResult) {
onResult(std::move(r));
}
});
}
std::future<Response> GatewayController::postFuture(const std::string &endpoint, const std::string &payload,
const FailoverContext &ctx, CancellationToken *cancel)
{
auto promise = std::make_shared<std::promise<Response>>();
std::future<Response> fut = promise->get_future();
Impl *impl = m_impl.get();
impl->dbg("postFuture: submitting to pool (caller thread=" + threadIdStr() + ")");
impl->pool.submit([impl, endpoint, payload, ctx, cancel, promise]() {
impl->dbg("postFuture: running on pool thread=" + threadIdStr());
promise->set_value(impl->executePost(endpoint, payload, ctx, cancel));
});
return fut;
}
}

View File

@@ -0,0 +1,132 @@
#ifdef AGW_HAVE_CURL
#include "http/curl_client.h"
#include <mutex>
#include <string>
#include <curl/curl.h>
namespace agw
{
namespace
{
std::once_flag g_curlInitOnce;
void ensureCurlGlobalInit()
{
std::call_once(g_curlInitOnce, []() { curl_global_init(CURL_GLOBAL_DEFAULT); });
}
std::size_t writeCallback(char *ptr, std::size_t size, std::size_t nmemb, void *userdata)
{
const std::size_t total = size * nmemb;
auto *buf = static_cast<std::string *>(userdata);
buf->append(ptr, total);
return total;
}
int xferCallback(void *clientp, curl_off_t, curl_off_t, curl_off_t, curl_off_t)
{
auto *check = static_cast<const std::function<bool()> *>(clientp);
if (check && *check && (*check)()) {
return 1;
}
return 0;
}
TransportError mapCurlError(CURLcode code, bool &sslError)
{
sslError = false;
switch (code) {
case CURLE_OK: return TransportError::None;
case CURLE_OPERATION_TIMEDOUT: return TransportError::Timeout;
case CURLE_ABORTED_BY_CALLBACK: return TransportError::Canceled;
case CURLE_SSL_CONNECT_ERROR:
case CURLE_PEER_FAILED_VERIFICATION:
case CURLE_SSL_CERTPROBLEM:
case CURLE_SSL_CIPHER:
case CURLE_SSL_CACERT_BADFILE:
case CURLE_SSL_ISSUER_ERROR: sslError = true; return TransportError::ConnectionError;
default: return TransportError::ConnectionError;
}
}
}
CurlHttpClient::CurlHttpClient()
{
ensureCurlGlobalInit();
}
CurlHttpClient::~CurlHttpClient() = default;
HttpResponse CurlHttpClient::send(const HttpRequest &request)
{
HttpResponse response;
CURL *curl = curl_easy_init();
if (!curl) {
response.error = TransportError::ConnectionError;
response.errorString = "curl_easy_init failed";
return response;
}
struct curl_slist *headers = nullptr;
for (const auto &h : request.headers) {
const std::string line = h.first + ": " + h.second;
headers = curl_slist_append(headers, line.c_str());
}
curl_easy_setopt(curl, CURLOPT_URL, request.url.c_str());
if (headers) {
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
}
if (request.timeoutMsecs > 0) {
curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, static_cast<long>(request.timeoutMsecs));
}
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response.body);
curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L);
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "");
if (request.cancelCheck) {
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, xferCallback);
curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &request.cancelCheck);
}
if (request.method == "POST") {
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request.body.data());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, static_cast<long>(request.body.size()));
} else {
curl_easy_setopt(curl, CURLOPT_HTTPGET, 1L);
}
const CURLcode code = curl_easy_perform(curl);
bool sslError = false;
response.error = mapCurlError(code, sslError);
response.sslError = sslError;
if (code != CURLE_OK) {
response.errorString = curl_easy_strerror(code);
}
long httpCode = 0;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &httpCode);
response.httpStatusCode = static_cast<int>(httpCode);
if (headers) {
curl_slist_free_all(headers);
}
curl_easy_cleanup(curl);
return response;
}
std::unique_ptr<IHttpClient> makeDefaultHttpClient()
{
return std::make_unique<CurlHttpClient>();
}
}
#endif

View File

@@ -0,0 +1,17 @@
#ifndef AGW_HTTP_CURL_CLIENT_H
#define AGW_HTTP_CURL_CLIENT_H
#include "agw/http.h"
namespace agw
{
class CurlHttpClient : public IHttpClient
{
public:
CurlHttpClient();
~CurlHttpClient() override;
HttpResponse send(const HttpRequest &request) override;
};
}
#endif

View File

@@ -0,0 +1,15 @@
#ifndef AGW_HAVE_CURL
#include <stdexcept>
#include "agw/http.h"
namespace agw
{
std::unique_ptr<IHttpClient> makeDefaultHttpClient()
{
throw std::runtime_error("agw: SDK built without libcurl; provide Config::httpClient (your own IHttpClient)");
}
}
#endif

View File

@@ -0,0 +1,133 @@
#include "protocol/error_mapping.h"
#include <algorithm>
#include <cctype>
#include "protocol/keys.h"
#include "util/json.h"
namespace agw::protocol
{
namespace
{
constexpr int kConflict = 409;
constexpr int kNotFound = 404;
constexpr int kNotImplemented = 501;
constexpr int kPaymentRequired = 402;
constexpr int kTooManyRequests = 429;
constexpr int kRequestTimeout = 408;
constexpr int kUnprocessableEntity = 422;
constexpr const char *kUnprocessableSubscriptionMessage =
"Failed to retrieve subscription information. Is it activated?";
constexpr const char *kTrialAlreadyUsedMessage = "trial subscription already used";
std::string trim(const std::string &s)
{
std::size_t b = 0;
std::size_t e = s.size();
while (b < e && std::isspace(static_cast<unsigned char>(s[b])))
++b;
while (e > b && std::isspace(static_cast<unsigned char>(s[e - 1])))
--e;
return s.substr(b, e - b);
}
std::string toLower(std::string s)
{
std::transform(s.begin(), s.end(), s.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
return s;
}
bool containsCI(const std::string &haystack, const std::string &needle)
{
return toLower(haystack).find(toLower(needle)) != std::string::npos;
}
std::string messageFrom(const util::Json &obj)
{
auto it = obj.find(keys::message);
if (it != obj.end() && it->is_string()) {
return trim(it->get<std::string>());
}
return { };
}
}
ErrorCode mapResponseError(bool sslError, TransportError transportError, const std::string &decryptedBody)
{
if (sslError) {
return ErrorCode::ApiConfigSslError;
}
if (transportError == TransportError::Timeout || transportError == TransportError::Canceled) {
return ErrorCode::ApiConfigTimeoutError;
}
if (transportError == TransportError::OperationNotImplemented) {
return ErrorCode::ApiUpdateRequestError;
}
util::Json obj;
bool isObject = false;
try {
obj = util::Json::parse(decryptedBody);
isObject = obj.is_object();
} catch (...) {
isObject = false;
}
if (isObject) {
int httpStatus = -1;
if (auto it = obj.find(keys::httpStatus); it != obj.end() && it->is_number_integer()) {
httpStatus = it->get<int>();
}
const std::string message = messageFrom(obj);
if (httpStatus == kTooManyRequests) {
return ErrorCode::ApiRateLimitError;
}
if (httpStatus == kConflict) {
if (containsCI(message, kTrialAlreadyUsedMessage)) {
return ErrorCode::ApiTrialAlreadyUsedError;
}
return ErrorCode::ApiConfigLimitError;
}
if (httpStatus == kNotFound) {
return ErrorCode::ApiNotFoundError;
}
if (httpStatus == kRequestTimeout) {
return ErrorCode::ApiConfigTimeoutError;
}
if (httpStatus == kNotImplemented) {
return ErrorCode::ApiUpdateRequestError;
}
if (httpStatus == kUnprocessableEntity) {
if (message == kUnprocessableSubscriptionMessage) {
return ErrorCode::ApiSubscriptionExpiredError;
}
return ErrorCode::ApiConfigDownloadError;
}
if (httpStatus == kPaymentRequired) {
if (containsCI(message, "refresh_captcha")) {
return ErrorCode::ApiCaptchaRefreshError;
}
if (containsCI(message, "invalid_captcha")) {
return ErrorCode::ApiCaptchaInvalidError;
}
if (obj.contains("captcha_id") || obj.contains("captcha_image")
|| containsCI(message, "rate_limit_exceeded")) {
return ErrorCode::ApiCaptchaRequiredError;
}
return ErrorCode::ApiSubscriptionNotActiveError;
}
if (httpStatus >= 300) {
return ErrorCode::ApiConfigDownloadError;
}
}
if (transportError == TransportError::None) {
return ErrorCode::NoError;
}
return ErrorCode::ApiConfigDownloadError;
}
}

View File

@@ -0,0 +1,14 @@
#ifndef AGW_PROTOCOL_ERROR_MAPPING_H
#define AGW_PROTOCOL_ERROR_MAPPING_H
#include <string>
#include "agw/http.h"
#include "agw/types.h"
namespace agw::protocol
{
ErrorCode mapResponseError(bool sslError, TransportError transportError, const std::string &decryptedBody);
}
#endif

View File

@@ -0,0 +1,18 @@
#ifndef AGW_PROTOCOL_KEYS_H
#define AGW_PROTOCOL_KEYS_H
namespace agw::protocol::keys {
inline constexpr const char *aesKey = "aes_key";
inline constexpr const char *aesIv = "aes_iv";
inline constexpr const char *aesSalt = "aes_salt";
inline constexpr const char *apiPayload = "api_payload";
inline constexpr const char *keyPayload = "key_payload";
inline constexpr const char *serviceType = "service_type";
inline constexpr const char *userCountryCode = "user_country_code";
inline constexpr const char *httpStatus = "http_status";
inline constexpr const char *message = "message";
}
#endif

View File

@@ -0,0 +1,55 @@
#include "protocol/request_builder.h"
#include "crypto/aes.h"
#include "crypto/rsa.h"
#include "protocol/keys.h"
#include "util/base64.h"
#include "util/json.h"
namespace agw::protocol
{
namespace
{
std::vector<std::uint8_t> bytesOf(const std::string &s)
{
return std::vector<std::uint8_t>(s.begin(), s.end());
}
}
EncryptedRequest buildEncryptedRequest(const std::string &payload, const std::string &publicKeyPem, crypto::IRng &rng)
{
namespace k = keys;
EncryptedRequest out;
out.key = rng.bytes(32);
out.iv = rng.bytes(32);
out.salt = rng.bytes(8);
if (!crypto::rsaPublicKeyValid(publicKeyPem)) {
out.error = ErrorCode::ApiMissingAgwPublicKey;
return out;
}
util::Json keysJson;
keysJson[k::aesKey] = util::base64Encode(out.key);
keysJson[k::aesIv] = util::base64Encode(out.iv);
keysJson[k::aesSalt] = util::base64Encode(out.salt);
const std::string keysSerialized = util::qtIndentedDump(keysJson);
std::string keyPayloadB64;
std::string apiPayloadB64;
try {
keyPayloadB64 = util::base64Encode(crypto::rsaEncryptPublicPkcs1(bytesOf(keysSerialized), publicKeyPem));
apiPayloadB64 = util::base64Encode(crypto::aesEncryptCbc(bytesOf(payload), out.key, out.iv));
} catch (...) {
out.error = ErrorCode::ApiConfigDecryptionError;
return out;
}
util::Json body;
body[k::keyPayload] = keyPayloadB64;
body[k::apiPayload] = apiPayloadB64;
out.body = util::qtIndentedDump(body);
return out;
}
}

View File

@@ -0,0 +1,26 @@
#ifndef AGW_PROTOCOL_REQUEST_BUILDER_H
#define AGW_PROTOCOL_REQUEST_BUILDER_H
#include <cstdint>
#include <string>
#include <vector>
#include "agw/types.h"
#include "crypto/rng.h"
namespace agw::protocol
{
struct EncryptedRequest
{
std::string body;
std::vector<std::uint8_t> key;
std::vector<std::uint8_t> iv;
std::vector<std::uint8_t> salt;
ErrorCode error = ErrorCode::NoError;
};
EncryptedRequest buildEncryptedRequest(const std::string &payload, const std::string &publicKeyPem,
crypto::IRng &rng);
}
#endif

View File

@@ -0,0 +1,24 @@
#include "protocol/response.h"
#include "crypto/aes.h"
namespace agw::protocol
{
DecryptResult tryDecryptResponse(const std::string &encrypted, const std::vector<std::uint8_t> &key,
const std::vector<std::uint8_t> &iv)
{
DecryptResult result;
result.decryptedBody = encrypted;
result.ok = false;
try {
const std::vector<std::uint8_t> in(encrypted.begin(), encrypted.end());
const std::vector<std::uint8_t> out = crypto::aesDecryptCbc(in, key, iv);
result.decryptedBody.assign(out.begin(), out.end());
result.ok = true;
} catch (...) {
result.decryptedBody = encrypted;
result.ok = false;
}
return result;
}
}

View File

@@ -0,0 +1,20 @@
#ifndef AGW_PROTOCOL_RESPONSE_H
#define AGW_PROTOCOL_RESPONSE_H
#include <cstdint>
#include <string>
#include <vector>
namespace agw::protocol
{
struct DecryptResult
{
std::string decryptedBody;
bool ok = false;
};
DecryptResult tryDecryptResponse(const std::string &encrypted, const std::vector<std::uint8_t> &key,
const std::vector<std::uint8_t> &iv);
}
#endif

108
agw-sdk/src/util/base64.cpp Normal file
View File

@@ -0,0 +1,108 @@
#include "base64.h"
#include <array>
namespace agw::util
{
namespace
{
const char *kStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
const char *kUrl = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
std::string encode(const std::vector<std::uint8_t> &data, const char *alphabet, bool pad)
{
std::string out;
out.reserve((data.size() + 2) / 3 * 4);
std::size_t i = 0;
const std::size_t n = data.size();
while (i + 3 <= n) {
const std::uint32_t v = (std::uint32_t(data[i]) << 16) | (std::uint32_t(data[i + 1]) << 8) | data[i + 2];
out.push_back(alphabet[(v >> 18) & 0x3F]);
out.push_back(alphabet[(v >> 12) & 0x3F]);
out.push_back(alphabet[(v >> 6) & 0x3F]);
out.push_back(alphabet[v & 0x3F]);
i += 3;
}
const std::size_t rem = n - i;
if (rem == 1) {
const std::uint32_t v = std::uint32_t(data[i]) << 16;
out.push_back(alphabet[(v >> 18) & 0x3F]);
out.push_back(alphabet[(v >> 12) & 0x3F]);
if (pad) {
out.push_back('=');
out.push_back('=');
}
} else if (rem == 2) {
const std::uint32_t v = (std::uint32_t(data[i]) << 16) | (std::uint32_t(data[i + 1]) << 8);
out.push_back(alphabet[(v >> 18) & 0x3F]);
out.push_back(alphabet[(v >> 12) & 0x3F]);
out.push_back(alphabet[(v >> 6) & 0x3F]);
if (pad) {
out.push_back('=');
}
}
return out;
}
int decodeChar(char c)
{
if (c >= 'A' && c <= 'Z')
return c - 'A';
if (c >= 'a' && c <= 'z')
return c - 'a' + 26;
if (c >= '0' && c <= '9')
return c - '0' + 52;
if (c == '+' || c == '-')
return 62;
if (c == '/' || c == '_')
return 63;
return -1;
}
}
std::string base64Encode(const std::vector<std::uint8_t> &data)
{
return encode(data, kStd, true);
}
std::string base64UrlEncodeNoPad(const std::vector<std::uint8_t> &data)
{
return encode(data, kUrl, false);
}
std::string base64Encode(const std::string &data)
{
return base64Encode(std::vector<std::uint8_t>(data.begin(), data.end()));
}
std::vector<std::uint8_t> base64Decode(const std::string &text)
{
std::vector<std::uint8_t> out;
out.reserve(text.size() / 4 * 3 + 3);
std::array<int, 4> quad { };
int count = 0;
for (char c : text) {
const int v = decodeChar(c);
if (v < 0) {
continue;
}
quad[count++] = v;
if (count == 4) {
out.push_back(static_cast<std::uint8_t>((quad[0] << 2) | (quad[1] >> 4)));
out.push_back(static_cast<std::uint8_t>((quad[1] << 4) | (quad[2] >> 2)));
out.push_back(static_cast<std::uint8_t>((quad[2] << 6) | quad[3]));
count = 0;
}
}
if (count == 2) {
out.push_back(static_cast<std::uint8_t>((quad[0] << 2) | (quad[1] >> 4)));
} else if (count == 3) {
out.push_back(static_cast<std::uint8_t>((quad[0] << 2) | (quad[1] >> 4)));
out.push_back(static_cast<std::uint8_t>((quad[1] << 4) | (quad[2] >> 2)));
}
return out;
}
}

19
agw-sdk/src/util/base64.h Normal file
View File

@@ -0,0 +1,19 @@
#ifndef AGW_UTIL_BASE64_H
#define AGW_UTIL_BASE64_H
#include <cstdint>
#include <string>
#include <vector>
namespace agw::util
{
std::string base64Encode(const std::vector<std::uint8_t> &data);
std::string base64UrlEncodeNoPad(const std::vector<std::uint8_t> &data);
std::vector<std::uint8_t> base64Decode(const std::string &text);
std::string base64Encode(const std::string &data);
}
#endif

108
agw-sdk/src/util/json.cpp Normal file
View File

@@ -0,0 +1,108 @@
#include "json.h"
#include <cstdint>
namespace agw::util
{
namespace
{
const char *kHex = "0123456789abcdef";
void appendEscaped(std::string &out, const std::string &s)
{
out.push_back('"');
for (unsigned char c : s) {
switch (c) {
case '"': out += "\\\""; break;
case '\\': out += "\\\\"; break;
case '\b': out += "\\b"; break;
case '\f': out += "\\f"; break;
case '\n': out += "\\n"; break;
case '\r': out += "\\r"; break;
case '\t': out += "\\t"; break;
default:
if (c < 0x20) {
out += "\\u00";
out.push_back(kHex[c >> 4]);
out.push_back(kHex[c & 0x0F]);
} else {
out.push_back(static_cast<char>(c));
}
}
}
out.push_back('"');
}
void appendIndent(std::string &out, int level)
{
out.append(static_cast<std::size_t>(level) * 4, ' ');
}
void dumpValue(std::string &out, const Json &j, int indent);
void dumpObject(std::string &out, const Json &j, int indent)
{
out += "{\n";
const int inner = indent + 1;
std::size_t i = 0;
const std::size_t n = j.size();
for (auto it = j.begin(); it != j.end(); ++it, ++i) {
appendIndent(out, inner);
appendEscaped(out, it.key());
out += ": ";
dumpValue(out, it.value(), inner);
if (i + 1 < n) {
out.push_back(',');
}
out.push_back('\n');
}
appendIndent(out, indent);
out.push_back('}');
}
void dumpArray(std::string &out, const Json &j, int indent)
{
out += "[\n";
const int inner = indent + 1;
std::size_t i = 0;
const std::size_t n = j.size();
for (const auto &el : j) {
appendIndent(out, inner);
dumpValue(out, el, inner);
if (i + 1 < n) {
out.push_back(',');
}
out.push_back('\n');
++i;
}
appendIndent(out, indent);
out.push_back(']');
}
void dumpValue(std::string &out, const Json &j, int indent)
{
switch (j.type()) {
case Json::value_t::object: dumpObject(out, j, indent); break;
case Json::value_t::array: dumpArray(out, j, indent); break;
case Json::value_t::string: appendEscaped(out, j.get<std::string>()); break;
case Json::value_t::boolean: out += j.get<bool>() ? "true" : "false"; break;
case Json::value_t::null: out += "null"; break;
case Json::value_t::number_integer:
case Json::value_t::number_unsigned:
case Json::value_t::number_float:
default:
out += j.dump();
break;
}
}
}
std::string qtIndentedDump(const Json &j)
{
std::string out;
dumpValue(out, j, 0);
out.push_back('\n');
return out;
}
}

15
agw-sdk/src/util/json.h Normal file
View File

@@ -0,0 +1,15 @@
#ifndef AGW_UTIL_JSON_H
#define AGW_UTIL_JSON_H
#include <string>
#include <nlohmann/json.hpp>
namespace agw::util
{
using Json = nlohmann::json;
std::string qtIndentedDump(const Json &j);
}
#endif

View File

@@ -0,0 +1,59 @@
#include "util/thread_pool.h"
namespace agw::util
{
ThreadPool::ThreadPool(std::size_t threadCount)
{
if (threadCount == 0) {
threadCount = 1;
}
m_workers.reserve(threadCount);
for (std::size_t i = 0; i < threadCount; ++i) {
m_workers.emplace_back([this] { workerLoop(); });
}
}
ThreadPool::~ThreadPool()
{
{
std::lock_guard<std::mutex> lock(m_mutex);
m_stopping = true;
}
m_cv.notify_all();
for (auto &w : m_workers) {
if (w.joinable()) {
w.join();
}
}
}
void ThreadPool::submit(std::function<void()> task)
{
{
std::lock_guard<std::mutex> lock(m_mutex);
m_tasks.push(std::move(task));
}
m_cv.notify_one();
}
void ThreadPool::workerLoop()
{
for (;;) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(m_mutex);
m_cv.wait(lock, [this] { return m_stopping || !m_tasks.empty(); });
if (m_tasks.empty()) {
if (m_stopping) {
return;
}
continue;
}
task = std::move(m_tasks.front());
m_tasks.pop();
}
task();
}
}
}

View File

@@ -0,0 +1,36 @@
#ifndef AGW_UTIL_THREAD_POOL_H
#define AGW_UTIL_THREAD_POOL_H
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
namespace agw::util
{
class ThreadPool
{
public:
explicit ThreadPool(std::size_t threadCount);
~ThreadPool();
ThreadPool(const ThreadPool &) = delete;
ThreadPool &operator=(const ThreadPool &) = delete;
void submit(std::function<void()> task);
private:
void workerLoop();
std::vector<std::thread> m_workers;
std::queue<std::function<void()>> m_tasks;
std::mutex m_mutex;
std::condition_variable m_cv;
bool m_stopping = false;
};
}
#endif

48
agw-sdk/src/util/url.cpp Normal file
View File

@@ -0,0 +1,48 @@
#include "util/url.h"
namespace agw::util
{
std::string formatEndpoint(const std::string &endpoint, const std::string &host)
{
const std::string token = "%1";
const std::size_t pos = endpoint.find(token);
if (pos == std::string::npos) {
return endpoint;
}
std::string out = endpoint;
out.replace(pos, token.size(), host);
return out;
}
std::string extractHost(const std::string &url)
{
std::size_t start = 0;
const std::size_t scheme = url.find("://");
if (scheme != std::string::npos) {
start = scheme + 3;
}
std::size_t end = url.size();
for (std::size_t i = start; i < url.size(); ++i) {
const char c = url[i];
if (c == '/' || c == '?' || c == '#') {
end = i;
break;
}
}
std::string authority = url.substr(start, end - start);
const std::size_t at = authority.find('@');
if (at != std::string::npos) {
authority = authority.substr(at + 1);
}
const std::size_t colon = authority.find(':');
if (colon != std::string::npos) {
authority = authority.substr(0, colon);
}
return authority;
}
}

13
agw-sdk/src/util/url.h Normal file
View File

@@ -0,0 +1,13 @@
#ifndef AGW_UTIL_URL_H
#define AGW_UTIL_URL_H
#include <string>
namespace agw::util
{
std::string formatEndpoint(const std::string &endpoint, const std::string &host);
std::string extractHost(const std::string &url);
}
#endif

36
agw-sdk/src/util/uuid.cpp Normal file
View File

@@ -0,0 +1,36 @@
#include "uuid.h"
#include <cstdint>
#include <vector>
namespace agw::util
{
namespace
{
const char *kHex = "0123456789abcdef";
void appendHex(std::string &out, std::uint8_t b)
{
out.push_back(kHex[b >> 4]);
out.push_back(kHex[b & 0x0F]);
}
}
std::string makeUuidV4(crypto::IRng &rng)
{
std::vector<std::uint8_t> b = rng.bytes(16);
b[6] = static_cast<std::uint8_t>((b[6] & 0x0F) | 0x40);
b[8] = static_cast<std::uint8_t>((b[8] & 0x3F) | 0x80);
std::string out;
out.reserve(36);
for (int i = 0; i < 16; ++i) {
if (i == 4 || i == 6 || i == 8 || i == 10) {
out.push_back('-');
}
appendHex(out, b[i]);
}
return out;
}
}

13
agw-sdk/src/util/uuid.h Normal file
View File

@@ -0,0 +1,13 @@
#ifndef AGW_UTIL_UUID_H
#define AGW_UTIL_UUID_H
#include <string>
#include "crypto/rng.h"
namespace agw::util
{
std::string makeUuidV4(crypto::IRng &rng);
}
#endif

View File

@@ -0,0 +1,27 @@
# Тесты agw-sdk. Локально — на встроенном harness (agw_test.h), без внешнего фреймворка.
set(AGW_FIXTURES_DIR "${CMAKE_CURRENT_SOURCE_DIR}/golden/fixtures")
function(agw_add_test name src)
add_executable(${name} ${src})
target_include_directories(${name} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_SOURCE_DIR}/include
${CMAKE_SOURCE_DIR}/src
)
target_link_libraries(${name} PRIVATE agw OpenSSL::Crypto nlohmann_json::nlohmann_json Threads::Threads)
target_compile_definitions(${name} PRIVATE AGW_FIXTURES_DIR="${AGW_FIXTURES_DIR}")
add_test(NAME ${name} COMMAND ${name})
endfunction()
agw_add_test(test_crypto unit/test_crypto.cpp)
agw_add_test(test_json unit/test_json.cpp)
agw_add_test(test_golden golden/test_golden.cpp)
agw_add_test(test_error_mapping unit/test_error_mapping.cpp)
agw_add_test(test_bypass_policy unit/test_bypass_policy.cpp)
agw_add_test(test_proxy_list unit/test_proxy_list.cpp)
agw_add_test(test_thread_pool unit/test_thread_pool.cpp)
agw_add_test(test_post integration/test_post.cpp)
agw_add_test(test_failover integration/test_failover.cpp)
agw_add_test(test_async integration/test_async.cpp)
agw_add_test(test_c_abi integration/test_c_abi.cpp)

39
agw-sdk/tests/agw_test.h Normal file
View File

@@ -0,0 +1,39 @@
#ifndef AGW_TEST_H
#define AGW_TEST_H
#include <cstdio>
#include <cstdlib>
#include <string>
namespace agw_test {
inline int &failCount()
{
static int n = 0;
return n;
}
inline void report(bool ok, const char *expr, const char *file, int line)
{
if (!ok) {
std::fprintf(stderr, "FAIL: %s\n at %s:%d\n", expr, file, line);
++failCount();
}
}
inline void reportEq(const std::string &a, const std::string &b, const char *expr, const char *file, int line)
{
if (a != b) {
std::fprintf(stderr, "FAIL: %s\n at %s:%d\n lhs=[%s]\n rhs=[%s]\n",
expr, file, line, a.c_str(), b.c_str());
++failCount();
}
}
}
#define CHECK(expr) ::agw_test::report((expr), #expr, __FILE__, __LINE__)
#define CHECK_EQ(a, b) ::agw_test::reportEq((a), (b), #a " == " #b, __FILE__, __LINE__)
#define AGW_TEST_MAIN_RETURN() \
(::agw_test::failCount() == 0 ? (std::printf("OK\n"), 0) : (std::fprintf(stderr, "%d check(s) failed\n", ::agw_test::failCount()), 1))
#endif

View File

@@ -0,0 +1 @@
ыaГцЖ7И~▐>ъl╛▌Ц╗ЦaJ╩┤ъ%фTхO╒╢^

View File

@@ -0,0 +1 @@
{"hello":"world"}

View File

@@ -0,0 +1,4 @@
AES_KEY_HEX=000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f
AES_IV_HEX=101112131415161718191a1b1c1d1e1f
AES_PLAINTEXT={"hello":"world"}
AES_CIPHER_B64=2WHnAcP2N+l+jz7fbKyO46jjYUq7h98lxlTIT6K0Xg8=

View File

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDL3pxE7uI3RacQ
jvyFz5tbkL87aqpFBvAVFS0OFzVLbApaJ6nv2jZNfidbPCN5SMeNoa2kNTC+MZNQ
qORgr77TgaRuFap5dSun9qci0ll5Y/zBQHb8/Xihah8YpkbO/8SV1aFWLtWKQiQL
xrFTD9ShXC7S6IQrdGcngUhLShinWmjveZJ//B7no3wUP3xPF+EkTXTq5QyD/SfQ
/w0BUosy55sCn5OyP6iSJ4cqujA7WCEd48XpS5zgceqdnE98mvjhriMDfORXsSnx
ymsDnPyX/wkrLOynpN5KPoCMwXS0knTSGADUPMm/EQXa5fjEDxg0OnJQcJV/mZJk
4Rgi+gZJAgMBAAECggEAXWW2ob3u1POL/gIDninmOqStd0L+jnEHPCFfar0nJU5x
z6usJr4JcqcA0MNUXRQCl9gh/MCBfCCqJKG7PrBE9BDIi8ZROyN6xJAzMbi8VOiB
uucVnAFjak97v4ctmVeDcEFWkG0UVyrF6L82LZ9rAiGBMg5jvqStPWP1AskHUmM+
drJZYsrvqqqZLVbB34I6logBcD0IKEib8uBM3brLrO1t86XpLvOIJEjCsD+GRtmJ
vnjjVcIqHFM+VyA+RvpgMnfbTcG3D4YZhDtdgsbnOHs4mydM/I6C7a5pLftKsdoC
lKgb6CIqJoXvW2goljHfQiBre56hwhBjRgzmdo+BoQKBgQD+aM+P1+nAsIUxWmXf
jexL8LIQ+POxNztz/d9EIYLXOztiS3epwt8e/Xqffu4B5+D1j6u/gtnK90SNcMi2
IhetBwkTGdz6s/JHTt92f7okRbSnzwZwj03Ppkgg3VAPp3LL4nP/a1kcHi+9aqBB
kPPkQ0k9BS/JOmEVPOKpOx0ABwKBgQDNJOjDyhKesjjLK33xJdD6sdhwxFJrSMs3
TtMd0KvPu26Z+gn+5ybE/rzOoO7YZIUmB8AXvNKLu8U/5FRE6eOeaumz2LvFNCzE
IC6J2Oixt/lVxxuR0mSRTJI/hR0CtrofXRke8YjeU3VjtkMW6k2QrPAAMMTPieRW
fkfb8oWTLwKBgQCt0B/W77XFLxSgplkphgYlz/loPR4JOmoFEjLCkn6Y29/zhQnp
UrkrrBRl+ctUQ/7e5lx5yEVSNOOCGscWIG66iS76/NWL9vsVGt7zT8p106XcbEXD
CzUnJDztLybuuwFkKIAFxmqoGjuVls6MXSM0FYBpDy0Ztyfy4Zkd88QZawKBgHB8
rKWvSEZ8s2e0kXqJoe3VVzl+bTMm10eckWbn5U4jGKKV2KVNWpTqmd0zocRGWjxg
Q5TAlTLJ438FVK/1EDrtpPhY/51C3sksXFh5+B57It1GMHflRf/mXMs30pCKYcSQ
6BVvm/1NBjGG34LRN3b9XRy9oS2sDujelcilU1lBAoGBAMHx0oFSlNUHGofigo3p
26kWJw7BE/o8HVIrfI2vekXOgTVgJWGiubjp3qUKgNf1lbRy/Ur9F4ElW7hEINjC
aMJUZWc/xYwAmKHe/7ITZByfRXdGMwlvo8QudbzDC6vvCqn7wQW56kyTTEEHF54k
21jK19EBX4JWibzJotv8ShkU
-----END PRIVATE KEY-----

View File

@@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAy96cRO7iN0WnEI78hc+b
W5C/O2qqRQbwFRUtDhc1S2wKWiep79o2TX4nWzwjeUjHjaGtpDUwvjGTUKjkYK++
04GkbhWqeXUrp/anItJZeWP8wUB2/P14oWofGKZGzv/EldWhVi7VikIkC8axUw/U
oVwu0uiEK3RnJ4FIS0oYp1po73mSf/we56N8FD98TxfhJE106uUMg/0n0P8NAVKL
MuebAp+Tsj+okieHKrowO1ghHePF6Uuc4HHqnZxPfJr44a4jA3zkV7Ep8cprA5z8
l/8JKyzsp6TeSj6AjMF0tJJ00hgA1DzJvxEF2uX4xA8YNDpyUHCVf5mSZOEYIvoG
SQIDAQAB
-----END PUBLIC KEY-----

View File

@@ -0,0 +1,97 @@
#include "agw_test.h"
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "crypto/aes.h"
#include "crypto/hash.h"
#include "crypto/rsa.h"
#include "protocol/keys.h"
#include "util/base64.h"
#include "util/json.h"
using namespace agw;
namespace {
std::vector<std::uint8_t> bytesOf(const std::string &s)
{
return std::vector<std::uint8_t>(s.begin(), s.end());
}
std::string toStr(const std::vector<std::uint8_t> &v)
{
return std::string(v.begin(), v.end());
}
std::string readFile(const std::string &path)
{
std::ifstream f(path, std::ios::binary);
std::ostringstream ss;
ss << f.rdbuf();
return ss.str();
}
}
int main()
{
namespace k = protocol::keys;
const auto key = crypto::fromHex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f");
const auto iv = crypto::fromHex("101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f");
const auto salt = crypto::fromHex("a0a1a2a3a4a5a6a7");
const std::string payload = "{\"hello\":\"world\"}";
util::Json keysJson;
keysJson[k::aesKey] = util::base64Encode(key);
keysJson[k::aesIv] = util::base64Encode(iv);
keysJson[k::aesSalt] = util::base64Encode(salt);
const std::string keysSerialized = util::qtIndentedDump(keysJson);
const std::string expectedKeysJson =
"{\n"
" \"aes_iv\": \"EBESExQVFhcYGRobHB0eHyAhIiMkJSYnKCkqKywtLi8=\",\n"
" \"aes_key\": \"AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=\",\n"
" \"aes_salt\": \"oKGio6Slpqc=\"\n"
"}\n";
CHECK_EQ(keysSerialized, expectedKeysJson);
const auto apiCipher = crypto::aesEncryptCbc(bytesOf(payload), key, iv);
const std::string apiPayloadB64 = util::base64Encode(apiCipher);
CHECK_EQ(apiPayloadB64, std::string("2WHnAcP2N+l+jz7fbKyO46jjYUq7h98lxlTIT6K0Xg8="));
const std::string pub = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_pub.pem");
const std::string priv = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_priv.pem");
CHECK(!pub.empty());
CHECK(!priv.empty());
const auto keyCipher = crypto::rsaEncryptPublicPkcs1(bytesOf(keysSerialized), pub);
const std::string keyPayloadB64 = util::base64Encode(keyCipher);
const auto keyCipherBack = util::base64Decode(keyPayloadB64);
const auto recovered = crypto::rsaDecryptPrivatePkcs1(keyCipherBack, priv);
CHECK_EQ(toStr(recovered), keysSerialized);
util::Json body;
body[k::keyPayload] = keyPayloadB64;
body[k::apiPayload] = apiPayloadB64;
const std::string bodySerialized = util::qtIndentedDump(body);
util::Json parsed = util::Json::parse(bodySerialized);
CHECK_EQ(parsed[k::apiPayload].get<std::string>(), apiPayloadB64);
{
const auto cBack = util::base64Decode(parsed[k::keyPayload].get<std::string>());
const auto rec = crypto::rsaDecryptPrivatePkcs1(cBack, priv);
CHECK_EQ(toStr(rec), keysSerialized);
}
{
const auto respPlain = bytesOf("{\"ok\":true}");
const auto respCipher = crypto::aesEncryptCbc(respPlain, key, iv);
const auto back = crypto::aesDecryptCbc(respCipher, key, iv);
CHECK(back == respPlain);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,174 @@
#include "agw_test.h"
#include <atomic>
#include <chrono>
#include <fstream>
#include <future>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
#include "agw/cancellation.h"
#include "agw/gateway_controller.h"
#include "agw/config.h"
#include "crypto/aes.h"
#include "crypto/rsa.h"
#include "mock_gateway/mock_gateway.h"
#include "protocol/keys.h"
#include "util/base64.h"
#include "util/json.h"
using namespace agw;
namespace {
std::string readFile(const std::string &path)
{
std::ifstream f(path, std::ios::binary);
std::ostringstream ss;
ss << f.rdbuf();
return ss.str();
}
Config baseConfig(std::shared_ptr<IHttpClient> http, const std::string &pub)
{
Config c;
c.gatewayEndpoint = "gw.example.test";
c.agwPublicKeyPem = pub;
c.requestTimeoutMsecs = 5000;
c.httpClient = std::move(http);
return c;
}
class BlockingUntilCancelMock : public IHttpClient {
public:
std::atomic<int> entered{0};
HttpResponse send(const HttpRequest &req) override
{
entered.fetch_add(1);
while (!(req.cancelCheck && req.cancelCheck())) {
std::this_thread::sleep_for(std::chrono::milliseconds(2));
}
HttpResponse r;
r.error = TransportError::Canceled;
return r;
}
};
class StatelessMock : public IHttpClient {
public:
explicit StatelessMock(std::string priv) : m_priv(std::move(priv)) {}
std::atomic<int> count{0};
HttpResponse send(const HttpRequest &req) override
{
count.fetch_add(1);
namespace k = protocol::keys;
util::Json body = util::Json::parse(req.body);
const auto keyCipher = util::base64Decode(body[k::keyPayload].get<std::string>());
const auto keysBytes = crypto::rsaDecryptPrivatePkcs1(keyCipher, m_priv);
util::Json keysJson = util::Json::parse(std::string(keysBytes.begin(), keysBytes.end()));
const auto aesKey = util::base64Decode(keysJson[k::aesKey].get<std::string>());
const auto aesIv = util::base64Decode(keysJson[k::aesIv].get<std::string>());
const std::string plain = R"({"ok":true})";
const std::vector<std::uint8_t> pv(plain.begin(), plain.end());
const auto cipher = crypto::aesEncryptCbc(pv, aesKey, aesIv);
HttpResponse r;
r.httpStatusCode = 200;
r.body.assign(cipher.begin(), cipher.end());
return r;
}
private:
std::string m_priv;
};
}
int main()
{
const std::string pub = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_pub.pem");
const std::string priv = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_priv.pem");
const std::string endpoint = "https://%1/api/v1/test";
const FailoverContext ctx{"prem", "US"};
const std::string payload = R"({"hello":"world"})";
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->responsePlain = R"({"ok":true,"v":1})";
GatewayController client(baseConfig(mock, pub));
std::future<Response> f = client.postFuture(endpoint, payload, ctx);
Response r = f.get();
CHECK(r.error == ErrorCode::NoError);
CHECK_EQ(r.body, std::string(R"({"ok":true,"v":1})"));
CHECK_EQ(mock->lastDecryptedPayload, payload);
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->responsePlain = R"({"async":true})";
GatewayController client(baseConfig(mock, pub));
std::promise<Response> p;
std::future<Response> f = p.get_future();
client.postAsync(
endpoint, payload, [&p](Response r) { p.set_value(std::move(r)); }, ctx);
Response r = f.get();
CHECK(r.error == ErrorCode::NoError);
CHECK_EQ(r.body, std::string(R"({"async":true})"));
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
GatewayController client(baseConfig(mock, pub));
CancellationToken token;
token.cancel();
std::future<Response> f = client.postFuture(endpoint, payload, ctx, &token);
Response r = f.get();
CHECK(r.error == ErrorCode::Cancelled);
CHECK(mock->requestCount == 0);
}
{
auto mock = std::make_shared<BlockingUntilCancelMock>();
GatewayController client(baseConfig(mock, pub));
CancellationToken token;
std::promise<Response> p;
std::future<Response> f = p.get_future();
client.postAsync(
endpoint, payload, [&p](Response r) { p.set_value(std::move(r)); }, ctx, &token);
while (mock->entered.load() == 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(2));
}
token.cancel();
Response r = f.get();
CHECK(r.error == ErrorCode::Cancelled);
}
{
auto mock = std::make_shared<StatelessMock>(priv);
Config cfg = baseConfig(mock, pub);
cfg.threadPoolSize = 8;
GatewayController client(std::move(cfg));
constexpr int N = 64;
std::vector<std::future<Response>> futs;
futs.reserve(N);
for (int i = 0; i < N; ++i) {
futs.push_back(client.postFuture(endpoint, payload, ctx));
}
int ok = 0;
for (auto &fut : futs) {
Response r = fut.get();
if (r.error == ErrorCode::NoError && r.body == R"({"ok":true})") {
++ok;
}
}
CHECK(ok == N);
CHECK(mock->count.load() == N);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,118 @@
#include "agw_test.h"
#include <fstream>
#include <future>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "agw/c_abi.h"
#include "agw/types.h"
#include "detail/test_hooks.h"
#include "mock_gateway/mock_gateway.h"
namespace {
std::string readFile(const std::string &path)
{
std::ifstream f(path, std::ios::binary);
std::ostringstream ss;
ss << f.rdbuf();
return ss.str();
}
agw_config makeConfig(const char *gateway, const char *pem)
{
agw_config c{};
c.gateway_endpoint = gateway;
c.agw_public_key_pem = pem;
c.request_timeout_msecs = 5000;
return c;
}
struct AsyncSink {
std::promise<std::pair<int, std::string>> promise;
};
void asyncCallback(agw_response r, void *ud)
{
auto *sink = static_cast<AsyncSink *>(ud);
sink->promise.set_value({r.error, r.body ? std::string(r.body, r.body_len) : std::string()});
agw_response_free(&r);
}
}
int main()
{
const std::string pub = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_pub.pem");
const std::string priv = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_priv.pem");
const std::string payload = R"({"hello":"world"})";
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->responsePlain = R"({"ok":true,"c":1})";
agw::detail::setNextTestHttpClient(mock);
agw_config cfg = makeConfig("gw.example.test", pub.c_str());
agw_client *client = agw_client_create(&cfg);
CHECK(client != nullptr);
agw_response r = agw_client_post(client, "https://%1/api/v1/test", payload.c_str(), "prem", "US", nullptr);
CHECK(r.error == 0);
CHECK(r.body != nullptr);
CHECK_EQ(std::string(r.body, r.body_len), std::string(R"({"ok":true,"c":1})"));
CHECK_EQ(mock->lastDecryptedPayload, payload);
agw_response_free(&r);
CHECK(r.body == nullptr);
mock->responsePlain = R"({"async":1})";
AsyncSink sink;
auto fut = sink.promise.get_future();
agw_client_post_async(client, "https://%1/api/v1/test", payload.c_str(), "prem", "US",
&asyncCallback, &sink, nullptr);
auto [err, body] = fut.get();
CHECK(err == 0);
CHECK_EQ(body, std::string(R"({"async":1})"));
agw_client_destroy(client);
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
agw::detail::setNextTestHttpClient(mock);
agw_config cfg = makeConfig("gw.example.test", pub.c_str());
agw_client *client = agw_client_create(&cfg);
agw_cancel_token *token = agw_cancel_token_create();
CHECK(token != nullptr);
agw_cancel_token_cancel(token);
agw_response r = agw_client_post(client, "https://%1/api/v1/test", payload.c_str(), "", "", token);
CHECK(r.error == static_cast<int>(agw::ErrorCode::Cancelled));
CHECK(mock->requestCount == 0);
agw_response_free(&r);
agw_cancel_token_destroy(token);
agw_client_destroy(client);
}
{
agw_config cfg = makeConfig("gw.example.test", "not a pem");
agw_client *client = agw_client_create(&cfg);
CHECK(client != nullptr);
agw_response r = agw_client_post(client, "https://%1/x", payload.c_str(), "", "", nullptr);
CHECK(r.error == static_cast<int>(agw::ErrorCode::ApiMissingAgwPublicKey));
agw_response_free(&r);
agw_client_destroy(client);
}
{
CHECK(agw_client_create(nullptr) == nullptr);
agw_response r = agw_client_post(nullptr, "e", "p", "", "", nullptr);
CHECK(r.error != 0);
agw_response_free(&r);
agw_client_destroy(nullptr);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,125 @@
#include "agw_test.h"
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "agw/gateway_controller.h"
#include "agw/config.h"
#include "crypto/aes.h"
#include "crypto/rsa.h"
#include "protocol/keys.h"
#include "util/base64.h"
#include "util/json.h"
using namespace agw;
namespace {
std::string readFile(const std::string &path)
{
std::ifstream f(path, std::ios::binary);
std::ostringstream ss;
ss << f.rdbuf();
return ss.str();
}
bool contains(const std::string &h, const std::string &n) { return h.find(n) != std::string::npos; }
class FailoverMock : public IHttpClient {
public:
explicit FailoverMock(std::string priv) : m_priv(std::move(priv)) {}
int directPosts = 0, proxyPosts = 0, storageGets = 0, healthGets = 0;
HttpResponse send(const HttpRequest &req) override
{
HttpResponse resp;
resp.httpStatusCode = 200;
if (req.method == "GET") {
if (contains(req.url, "lmbd-health")) {
++healthGets;
if (!contains(req.url, "proxy.good.test")) {
resp.error = TransportError::ConnectionError;
}
return resp;
}
++storageGets;
resp.body = R"(["https://proxy.good.test/"])";
return resp;
}
namespace k = protocol::keys;
util::Json body = util::Json::parse(req.body);
const auto keyCipher = util::base64Decode(body[k::keyPayload].get<std::string>());
const auto keysBytes = crypto::rsaDecryptPrivatePkcs1(keyCipher, m_priv);
util::Json keysJson = util::Json::parse(std::string(keysBytes.begin(), keysBytes.end()));
const auto aesKey = util::base64Decode(keysJson[k::aesKey].get<std::string>());
const auto aesIv = util::base64Decode(keysJson[k::aesIv].get<std::string>());
std::string plain;
if (contains(req.url, "proxy.good.test")) {
++proxyPosts;
plain = R"({"ok":true,"via":"proxy"})";
} else {
++directPosts;
plain = R"({"http_status":404,"message":"blocked"})";
}
const std::vector<std::uint8_t> pv(plain.begin(), plain.end());
const auto cipher = crypto::aesEncryptCbc(pv, aesKey, aesIv);
resp.body.assign(cipher.begin(), cipher.end());
return resp;
}
private:
std::string m_priv;
};
}
int main()
{
const std::string pub = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_pub.pem");
const std::string priv = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_priv.pem");
auto mock = std::make_shared<FailoverMock>(priv);
Config cfg;
cfg.gatewayEndpoint = "https://gw.example.test/";
cfg.agwPublicKeyPem = pub;
cfg.isDevEnvironment = true;
cfg.s3PrimaryEndpoints = {"https://s3.example.test/"};
cfg.requestTimeoutMsecs = 5000;
cfg.httpClient = mock;
GatewayController client(std::move(cfg));
const std::string endpoint = "%1api/v1/test";
const FailoverContext ctx{"prem", "US"};
const std::string payload = R"({"hello":"world"})";
{
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::NoError);
CHECK_EQ(r.body, std::string(R"({"ok":true,"via":"proxy"})"));
CHECK(mock->directPosts == 1);
CHECK(mock->storageGets >= 1);
CHECK(mock->healthGets >= 1);
CHECK(mock->proxyPosts == 1);
}
{
const int storageBefore = mock->storageGets;
const int healthBefore = mock->healthGets;
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::NoError);
CHECK_EQ(r.body, std::string(R"({"ok":true,"via":"proxy"})"));
CHECK(mock->storageGets == storageBefore);
CHECK(mock->healthGets == healthBefore);
CHECK(mock->directPosts == 1);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,103 @@
#include "agw_test.h"
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include "agw/gateway_controller.h"
#include "agw/config.h"
#include "mock_gateway/mock_gateway.h"
using namespace agw;
namespace {
std::string readFile(const std::string &path)
{
std::ifstream f(path, std::ios::binary);
std::ostringstream ss;
ss << f.rdbuf();
return ss.str();
}
Config baseConfig(std::shared_ptr<IHttpClient> http, const std::string &pubPem)
{
Config c;
c.gatewayEndpoint = "gw.example.test";
c.agwPublicKeyPem = pubPem;
c.requestTimeoutMsecs = 5000;
c.httpClient = std::move(http);
return c;
}
}
int main()
{
const std::string pub = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_pub.pem");
const std::string priv = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_priv.pem");
const std::string endpoint = "https://%1/api/v1/test";
const FailoverContext ctx{"premium", "US"};
const std::string payload = R"({"hello":"world","n":42})";
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->responsePlain = R"({"ok":true,"data":"hi"})";
std::string seenHost;
Config cfg = baseConfig(mock, pub);
cfg.onBeforeRequest = [&](const std::string &h) { seenHost = h; };
GatewayController client(std::move(cfg));
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::NoError);
CHECK_EQ(r.body, std::string(R"({"ok":true,"data":"hi"})"));
CHECK_EQ(mock->lastDecryptedPayload, payload);
CHECK_EQ(mock->lastUrl, std::string("https://gw.example.test/api/v1/test"));
CHECK_EQ(seenHost, std::string("gw.example.test"));
CHECK(mock->requestCount == 1);
CHECK(mock->lastRequestId.size() == 36);
CHECK(mock->lastRequestId[14] == '4');
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->responsePlain = R"({"http_status":409,"message":"limit"})";
GatewayController client(baseConfig(mock, pub));
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::ApiConfigLimitError);
CHECK_EQ(r.body, std::string(R"({"http_status":409,"message":"limit"})"));
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->simulateSsl = true;
GatewayController client(baseConfig(mock, pub));
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::ApiConfigSslError);
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
mock->simulateTransport = TransportError::Timeout;
GatewayController client(baseConfig(mock, pub));
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::ApiConfigTimeoutError);
}
{
auto mock = std::make_shared<agw_test::MockGateway>(priv);
Config cfg = baseConfig(mock, "not a pem key");
GatewayController client(std::move(cfg));
Response r = client.post(endpoint, payload, ctx);
CHECK(r.error == ErrorCode::ApiMissingAgwPublicKey);
CHECK(mock->requestCount == 0);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,77 @@
#ifndef AGW_TEST_MOCK_GATEWAY_H
#define AGW_TEST_MOCK_GATEWAY_H
#include <string>
#include <vector>
#include "agw/http.h"
#include "crypto/aes.h"
#include "crypto/rsa.h"
#include "protocol/keys.h"
#include "util/base64.h"
#include "util/json.h"
namespace agw_test {
class MockGateway : public agw::IHttpClient {
public:
explicit MockGateway(std::string privateKeyPem) : m_priv(std::move(privateKeyPem)) {}
std::string responsePlain = "{\"ok\":true}";
bool simulateSsl = false;
agw::TransportError simulateTransport = agw::TransportError::None;
int httpStatusCode = 200;
std::string lastUrl;
std::string lastRequestId;
std::string lastDecryptedPayload;
int requestCount = 0;
agw::HttpResponse send(const agw::HttpRequest &req) override
{
++requestCount;
lastUrl = req.url;
for (const auto &h : req.headers) {
if (h.first == "X-Client-Request-ID") {
lastRequestId = h.second;
}
}
agw::HttpResponse resp;
resp.httpStatusCode = httpStatusCode;
if (simulateSsl) {
resp.sslError = true;
resp.error = agw::TransportError::ConnectionError;
return resp;
}
if (simulateTransport != agw::TransportError::None) {
resp.error = simulateTransport;
return resp;
}
namespace k = agw::protocol::keys;
agw::util::Json body = agw::util::Json::parse(req.body);
const auto keyCipher = agw::util::base64Decode(body[k::keyPayload].get<std::string>());
const auto keysBytes = agw::crypto::rsaDecryptPrivatePkcs1(keyCipher, m_priv);
agw::util::Json keysJson = agw::util::Json::parse(std::string(keysBytes.begin(), keysBytes.end()));
const auto aesKey = agw::util::base64Decode(keysJson[k::aesKey].get<std::string>());
const auto aesIv = agw::util::base64Decode(keysJson[k::aesIv].get<std::string>());
const auto apiCipher = agw::util::base64Decode(body[k::apiPayload].get<std::string>());
const auto payloadBytes = agw::crypto::aesDecryptCbc(apiCipher, aesKey, aesIv);
lastDecryptedPayload.assign(payloadBytes.begin(), payloadBytes.end());
const std::vector<std::uint8_t> respPlain(responsePlain.begin(), responsePlain.end());
const auto respCipher = agw::crypto::aesEncryptCbc(respPlain, aesKey, aesIv);
resp.body.assign(respCipher.begin(), respCipher.end());
resp.error = agw::TransportError::None;
return resp;
}
private:
std::string m_priv;
};
}
#endif

24765
agw-sdk/tests/third_party/nlohmann/json.hpp vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,48 @@
#include "agw_test.h"
#include <string>
#include "failover/bypass_policy.h"
using namespace agw;
using agw::failover::shouldBypassProxy;
namespace {
bool bypassBody(const std::string &body)
{
return shouldBypassProxy(TransportError::None, body, true);
}
}
int main()
{
CHECK(shouldBypassProxy(TransportError::None, "garbage", false) == true);
CHECK(shouldBypassProxy(TransportError::Timeout, R"({"http_status":200})", true) == true);
CHECK(shouldBypassProxy(TransportError::Canceled, R"({"http_status":200})", true) == true);
CHECK(bypassBody("<html><body>blocked</body></html>") == true);
CHECK(bypassBody(R"({"http_status":408})") == false);
CHECK(bypassBody(R"({"http_status":409})") == false);
CHECK(bypassBody(R"({"http_status":402})") == false);
CHECK(bypassBody(R"({"http_status":404,"message":"whatever"})") == true);
CHECK(bypassBody(R"({"http_status":404,"message":"No active configuration found for x"})") == false);
CHECK(bypassBody(R"({"http_status":404,"detail":"Account not found."})") == false);
CHECK(bypassBody(R"({"http_status":404,"message":"Session not found"})") == false);
CHECK(bypassBody(R"({"http_status":501})") == true);
CHECK(bypassBody(R"({"http_status":501,"message":"client version update is required"})") == false);
CHECK(bypassBody(R"({"http_status":422,"message":"Failed to retrieve subscription information. Is it activated?"})") == false);
CHECK(bypassBody(R"({"http_status":422,"message":"other"})") == true);
CHECK(bypassBody(R"({"http_status":200})") == false);
CHECK(bypassBody("plain ok") == false);
CHECK(shouldBypassProxy(TransportError::ConnectionError, R"({"http_status":200})", true) == true);
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,123 @@
#include "agw_test.h"
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include "crypto/aes.h"
#include "crypto/hash.h"
#include "crypto/rng.h"
#include "crypto/rsa.h"
#include "util/base64.h"
#include "util/uuid.h"
using namespace agw;
namespace {
std::vector<std::uint8_t> bytesOf(const std::string &s)
{
return std::vector<std::uint8_t>(s.begin(), s.end());
}
std::string readFile(const std::string &path)
{
std::ifstream f(path, std::ios::binary);
std::ostringstream ss;
ss << f.rdbuf();
return ss.str();
}
class FixedRng : public crypto::IRng {
public:
explicit FixedRng(std::vector<std::uint8_t> data) : m_data(std::move(data)) {}
std::vector<std::uint8_t> bytes(std::size_t n) override
{
std::vector<std::uint8_t> out(n);
for (std::size_t i = 0; i < n; ++i) {
out[i] = m_data[(m_pos + i) % m_data.size()];
}
m_pos += n;
return out;
}
private:
std::vector<std::uint8_t> m_data;
std::size_t m_pos = 0;
};
}
int main()
{
CHECK_EQ(util::base64Encode(std::string("")), std::string(""));
CHECK_EQ(util::base64Encode(std::string("f")), std::string("Zg=="));
CHECK_EQ(util::base64Encode(std::string("fo")), std::string("Zm8="));
CHECK_EQ(util::base64Encode(std::string("foo")), std::string("Zm9v"));
CHECK_EQ(util::base64Encode(std::string("foob")), std::string("Zm9vYg=="));
CHECK_EQ(util::base64Encode(std::string("fooba")), std::string("Zm9vYmE="));
CHECK_EQ(util::base64Encode(std::string("foobar")), std::string("Zm9vYmFy"));
{
std::vector<std::uint8_t> v{0xfb, 0xff, 0xbf};
CHECK_EQ(util::base64UrlEncodeNoPad(v), std::string("-_-_"));
CHECK_EQ(util::base64Encode(v), std::string("+/+/"));
}
{
auto v = bytesOf("any carnal pleasure.");
CHECK(util::base64Decode(util::base64Encode(v)) == v);
CHECK(util::base64Decode(util::base64UrlEncodeNoPad(v)) == v);
}
CHECK_EQ(crypto::toHex(crypto::sha512(bytesOf("abc"))),
std::string("ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a"
"2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f"));
{
std::vector<std::uint8_t> v{0x00, 0x01, 0xab, 0xff};
CHECK_EQ(crypto::toHex(v), std::string("0001abff"));
CHECK(crypto::fromHex("0001abff") == v);
}
{
auto key = crypto::fromHex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f");
auto iv = crypto::fromHex("101112131415161718191a1b1c1d1e1f");
auto pt = bytesOf("{\"hello\":\"world\"}");
auto ct = crypto::aesEncryptCbc(pt, key, iv);
CHECK_EQ(util::base64Encode(ct), std::string("2WHnAcP2N+l+jz7fbKyO46jjYUq7h98lxlTIT6K0Xg8="));
CHECK(crypto::aesDecryptCbc(ct, key, iv) == pt);
}
{
auto key = crypto::fromHex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f");
auto iv16 = crypto::fromHex("101112131415161718191a1b1c1d1e1f");
auto iv32 = crypto::fromHex("101112131415161718191a1b1c1d1e1fdeadbeefdeadbeefdeadbeefdeadbeef");
auto pt = bytesOf("{\"hello\":\"world\"}");
CHECK(crypto::aesEncryptCbc(pt, key, iv16) == crypto::aesEncryptCbc(pt, key, iv32));
}
{
std::string pub = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_pub.pem");
std::string priv = readFile(std::string(AGW_FIXTURES_DIR) + "/test_rsa_priv.pem");
CHECK(!pub.empty());
CHECK(!priv.empty());
auto msg = bytesOf("{\"aes_key\":\"...\",\"aes_iv\":\"...\",\"aes_salt\":\"...\"}");
auto ct = crypto::rsaEncryptPublicPkcs1(msg, pub);
auto rt = crypto::rsaDecryptPrivatePkcs1(ct, priv);
CHECK(rt == msg);
auto ct2 = crypto::rsaEncryptPublicPkcs1(msg, pub);
CHECK(ct != ct2);
}
{
FixedRng rng(std::vector<std::uint8_t>(16, 0xFF));
std::string u = util::makeUuidV4(rng);
CHECK_EQ(u, std::string("ffffffff-ffff-4fff-bfff-ffffffffffff"));
CHECK(u.size() == 36);
CHECK(u[14] == '4');
CHECK(u[19] == '8' || u[19] == '9' || u[19] == 'a' || u[19] == 'b');
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,54 @@
#include "agw_test.h"
#include <string>
#include "protocol/error_mapping.h"
using namespace agw;
using agw::protocol::mapResponseError;
namespace {
int code(ErrorCode e) { return static_cast<int>(e); }
ErrorCode mapBody(const std::string &body)
{
return mapResponseError(false, TransportError::None, body);
}
}
int main()
{
CHECK(mapResponseError(true, TransportError::None, "") == ErrorCode::ApiConfigSslError);
CHECK(mapResponseError(false, TransportError::Timeout, "") == ErrorCode::ApiConfigTimeoutError);
CHECK(mapResponseError(false, TransportError::Canceled, "") == ErrorCode::ApiConfigTimeoutError);
CHECK(mapResponseError(false, TransportError::OperationNotImplemented, "") == ErrorCode::ApiUpdateRequestError);
CHECK(mapResponseError(false, TransportError::ConnectionError, "") == ErrorCode::ApiConfigDownloadError);
CHECK(mapResponseError(false, TransportError::None, "") == ErrorCode::NoError);
CHECK(mapBody("not a json") == ErrorCode::NoError);
CHECK(mapBody(R"({"http_status":429})") == ErrorCode::ApiRateLimitError);
CHECK(mapBody(R"({"http_status":409})") == ErrorCode::ApiConfigLimitError);
CHECK(mapBody(R"({"http_status":409,"message":"Trial Subscription Already Used"})") == ErrorCode::ApiTrialAlreadyUsedError);
CHECK(mapBody(R"({"http_status":404})") == ErrorCode::ApiNotFoundError);
CHECK(mapBody(R"({"http_status":408})") == ErrorCode::ApiConfigTimeoutError);
CHECK(mapBody(R"({"http_status":501})") == ErrorCode::ApiUpdateRequestError);
CHECK(mapBody(R"({"http_status":422,"message":"Failed to retrieve subscription information. Is it activated?"})")
== ErrorCode::ApiSubscriptionExpiredError);
CHECK(mapBody(R"({"http_status":422,"message":"something else"})") == ErrorCode::ApiConfigDownloadError);
CHECK(mapBody(R"({"http_status":402,"message":"refresh_captcha"})") == ErrorCode::ApiCaptchaRefreshError);
CHECK(mapBody(R"({"http_status":402,"message":"invalid_captcha"})") == ErrorCode::ApiCaptchaInvalidError);
CHECK(mapBody(R"({"http_status":402,"captcha_id":"x"})") == ErrorCode::ApiCaptchaRequiredError);
CHECK(mapBody(R"({"http_status":402,"captcha_image":"x"})") == ErrorCode::ApiCaptchaRequiredError);
CHECK(mapBody(R"({"http_status":402,"message":"rate_limit_exceeded"})") == ErrorCode::ApiCaptchaRequiredError);
CHECK(mapBody(R"({"http_status":402,"message":"nope"})") == ErrorCode::ApiSubscriptionNotActiveError);
CHECK(mapBody(R"({"http_status":500})") == ErrorCode::ApiConfigDownloadError);
CHECK(mapBody(R"({"http_status":200})") == ErrorCode::NoError);
(void)code;
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,61 @@
#include "agw_test.h"
#include <string>
#include "util/json.h"
using namespace agw;
int main()
{
{
util::Json j;
j["aes_key"] = "KEY";
j["aes_iv"] = "IV";
j["aes_salt"] = "SALT";
const std::string expected =
"{\n"
" \"aes_iv\": \"IV\",\n"
" \"aes_key\": \"KEY\",\n"
" \"aes_salt\": \"SALT\"\n"
"}\n";
CHECK_EQ(util::qtIndentedDump(j), expected);
}
{
util::Json j;
j["key_payload"] = "K";
j["api_payload"] = "A";
const std::string expected =
"{\n"
" \"api_payload\": \"A\",\n"
" \"key_payload\": \"K\"\n"
"}\n";
CHECK_EQ(util::qtIndentedDump(j), expected);
}
{
util::Json j;
j["s"] = std::string("a\"b\\c\nd\te\x01");
const std::string expected =
"{\n"
" \"s\": \"a\\\"b\\\\c\\nd\\te\\u0001\"\n"
"}\n";
CHECK_EQ(util::qtIndentedDump(j), expected);
}
{
util::Json j;
j["outer"]["inner"] = "v";
const std::string expected =
"{\n"
" \"outer\": {\n"
" \"inner\": \"v\"\n"
" }\n"
"}\n";
CHECK_EQ(util::qtIndentedDump(j), expected);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,87 @@
#include "agw_test.h"
#include <string>
#include <vector>
#include "crypto/aes.h"
#include "crypto/hash.h"
#include "failover/proxy_list.h"
#include "util/base64.h"
using namespace agw;
namespace {
std::vector<std::uint8_t> bytesOf(const std::string &s)
{
return std::vector<std::uint8_t>(s.begin(), s.end());
}
}
int main()
{
{
const std::vector<std::string> primary{"https://a/", "https://b/"};
const std::vector<std::string> fallback{"https://f/"};
const FailoverContext ctx{"prem", "US"};
const std::string enc =
util::base64UrlEncodeNoPad(bytesOf("endpoints-prem-US"));
const auto urls = failover::buildStorageUrls(primary, fallback, ctx);
const std::vector<std::string> expected{
"https://a/" + enc + ".json",
"https://b/" + enc + ".json",
"https://a/endpoints.json",
"https://b/endpoints.json",
"https://f/" + enc + ".json",
"https://f/endpoints.json",
};
CHECK(urls == expected);
}
{
const std::vector<std::string> primary{"https://a/", "https://b/"};
const std::vector<std::string> fallback{"https://f/"};
const FailoverContext ctx{"", ""};
const auto urls = failover::buildStorageUrls(primary, fallback, ctx);
const std::vector<std::string> expected{
"https://a/endpoints.json",
"https://b/endpoints.json",
"https://f/endpoints.json",
};
CHECK(urls == expected);
}
{
const auto list = failover::decodeProxyList(R"(["https://p1/","https://p2/"])", true, "");
const std::vector<std::string> expected{"https://p1/", "https://p2/"};
CHECK(list == expected);
CHECK(failover::decodeProxyList(R"({"x":1})", true, "").empty());
}
{
const std::string pub = "PUBKEYDATA-pem-like";
const std::string h = crypto::toHex(crypto::sha512(bytesOf(pub)));
const auto key = crypto::fromHex(h.substr(0, 64));
const auto iv = crypto::fromHex(h.substr(64, 32));
const std::string arr = R"(["https://prod1/","https://prod2/"])";
const auto cipher = crypto::aesEncryptCbc(bytesOf(arr), key, iv);
const std::string b64 = util::base64Encode(cipher);
const auto list = failover::decodeProxyList(b64, false, pub);
const std::vector<std::string> expected{"https://prod1/", "https://prod2/"};
CHECK(list == expected);
bool threw = false;
try {
failover::decodeProxyList("###not base64 cipher###", false, pub);
} catch (...) {
threw = true;
}
CHECK(threw);
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -0,0 +1,33 @@
#include "agw_test.h"
#include <atomic>
#include <memory>
#include "util/thread_pool.h"
using namespace agw;
int main()
{
{
std::atomic<int> counter{0};
{
util::ThreadPool pool(4);
for (int i = 0; i < 1000; ++i) {
pool.submit([&counter] { counter.fetch_add(1, std::memory_order_relaxed); });
}
}
CHECK(counter.load() == 1000);
}
{
std::atomic<bool> ran{false};
{
util::ThreadPool pool(0);
pool.submit([&ran] { ran.store(true); });
}
CHECK(ran.load());
}
return AGW_TEST_MAIN_RETURN();
}

View File

@@ -46,3 +46,19 @@ list(APPEND LIBS OpenSSL::SSL OpenSSL::Crypto)
find_package(libssh REQUIRED)
list(APPEND LIBS ssh::ssh)
# AGW SDK — транспорт к API-шлюзу (agw-sdk, Qt-free). gatewayControllerAdapter линкуется как адаптер.
# Два режима:
# ON (по умолчанию) — собираем SDK ИЗ ИСХОДНИКОВ через add_subdirectory: можно заходить
# отладчиком внутрь SDK, символы и -g идут из основной сборки;
# OFF — потребляем готовый Conan-пакет agw-sdk (как другие нативные компоненты).
option(AGW_SDK_FROM_SOURCE "Build agw-sdk from source (debuggable) instead of the Conan package" ON)
if(AGW_SDK_FROM_SOURCE)
message(STATUS "agw-sdk: building FROM SOURCE (add_subdirectory) — debuggable")
add_subdirectory(${CMAKE_SOURCE_DIR}/agw-sdk ${CMAKE_BINARY_DIR}/agw-sdk)
list(APPEND LIBS agw::agw)
else()
message(STATUS "agw-sdk: consuming Conan package")
find_package(agw-sdk REQUIRED)
list(APPEND LIBS agw::agw)
endif()

View File

@@ -20,7 +20,7 @@ set(HEADERS ${HEADERS}
${CLIENT_ROOT_DIR}/core/utils/qrCodeUtils.h
${CLIENT_ROOT_DIR}/core/controllers/coreController.h
${CLIENT_ROOT_DIR}/core/controllers/coreSignalHandlers.h
${CLIENT_ROOT_DIR}/core/controllers/gatewayController.h
${CLIENT_ROOT_DIR}/core/controllers/gatewayControllerAdapter.h
${CLIENT_ROOT_DIR}/core/utils/selfhosted/sshSession.h
${CLIENT_ROOT_DIR}/core/controllers/serversController.h
${CLIENT_ROOT_DIR}/core/controllers/selfhosted/usersController.h
@@ -97,7 +97,7 @@ set(SOURCES ${SOURCES}
${CLIENT_ROOT_DIR}/core/utils/qrCodeUtils.cpp
${CLIENT_ROOT_DIR}/core/controllers/coreController.cpp
${CLIENT_ROOT_DIR}/core/controllers/coreSignalHandlers.cpp
${CLIENT_ROOT_DIR}/core/controllers/gatewayController.cpp
${CLIENT_ROOT_DIR}/core/controllers/gatewayControllerAdapter.cpp
${CLIENT_ROOT_DIR}/core/utils/selfhosted/sshSession.cpp
${CLIENT_ROOT_DIR}/core/controllers/serversController.cpp
${CLIENT_ROOT_DIR}/core/controllers/selfhosted/usersController.cpp

View File

@@ -1,6 +1,6 @@
#include "newsController.h"
#include "core/controllers/gatewayController.h"
#include "core/controllers/gatewayControllerAdapter.h"
#include "core/repositories/secureServersRepository.h"
#include "core/utils/constants/apiKeys.h"
#include "core/utils/constants/apiConstants.h"
@@ -74,7 +74,7 @@ QFuture<QPair<ErrorCode, QJsonArray>> NewsController::fetchNews()
return QtFuture::makeReadyFuture(qMakePair(ErrorCode::NoError, QJsonArray()));
}
auto gatewayController = QSharedPointer<GatewayController>::create(
auto gatewayController = QSharedPointer<GatewayControllerAdapter>::create(
m_appSettingsRepository->getGatewayEndpoint(),
m_appSettingsRepository->isDevGatewayEnv(),
apiDefs::requestTimeoutMsecs,

View File

@@ -10,7 +10,7 @@
#include <QSet>
#include <limits>
#include "core/controllers/gatewayController.h"
#include "core/controllers/gatewayControllerAdapter.h"
#include "core/utils/serverConfigUtils.h"
#include "core/utils/constants/apiKeys.h"
#include "core/utils/constants/apiConstants.h"
@@ -241,7 +241,7 @@ ErrorCode ServicesCatalogController::fillAvailableServices(QJsonObject &services
ErrorCode ServicesCatalogController::executeRequest(const QString &endpoint, const QJsonObject &apiPayload, QByteArray &responseBody)
{
GatewayController gatewayController(m_appSettingsRepository->getGatewayEndpoint(), m_appSettingsRepository->isDevGatewayEnv(), apiDefs::requestTimeoutMsecs,
GatewayControllerAdapter gatewayController(m_appSettingsRepository->getGatewayEndpoint(), m_appSettingsRepository->isDevGatewayEnv(), apiDefs::requestTimeoutMsecs,
m_appSettingsRepository->isStrictKillSwitchEnabled());
return gatewayController.post(endpoint, apiPayload, responseBody);
}

View File

@@ -21,7 +21,7 @@
#include "core/utils/constants/apiKeys.h"
#include "core/utils/constants/apiConstants.h"
#include "core/utils/api/apiUtils.h"
#include "core/controllers/gatewayController.h"
#include "core/controllers/gatewayControllerAdapter.h"
#include "core/utils/protocolEnum.h"
#include "core/protocols/protocolUtils.h"
#include "core/utils/constants/configKeys.h"
@@ -211,7 +211,7 @@ void SubscriptionController::updateApiConfigInJson(QJsonObject &serverConfigJson
ErrorCode SubscriptionController::executeRequest(const QString &endpoint, const QJsonObject &apiPayload, QByteArray &responseBody, bool isTestPurchase)
{
GatewayController gatewayController(m_appSettingsRepository->getGatewayEndpoint(isTestPurchase), m_appSettingsRepository->isDevGatewayEnv(isTestPurchase), apiDefs::requestTimeoutMsecs,
GatewayControllerAdapter gatewayController(m_appSettingsRepository->getGatewayEndpoint(isTestPurchase), m_appSettingsRepository->isDevGatewayEnv(isTestPurchase), apiDefs::requestTimeoutMsecs,
m_appSettingsRepository->isStrictKillSwitchEnabled());
return gatewayController.post(endpoint, apiPayload, responseBody);
}
@@ -946,7 +946,7 @@ QFuture<QPair<ErrorCode, QString>> SubscriptionController::getRenewalLink(const
apiPayload[apiDefs::key::cliVersion] = QString(APP_VERSION);
apiPayload[apiDefs::key::subscriptionStatus] = getSubscriptionStatusForRenewal(apiV2->apiConfig);
auto gatewayController = QSharedPointer<GatewayController>::create(m_appSettingsRepository->getGatewayEndpoint(isTestPurchase),
auto gatewayController = QSharedPointer<GatewayControllerAdapter>::create(m_appSettingsRepository->getGatewayEndpoint(isTestPurchase),
m_appSettingsRepository->isDevGatewayEnv(isTestPurchase),
apiDefs::requestTimeoutMsecs,
m_appSettingsRepository->isStrictKillSwitchEnabled());

View File

@@ -1,709 +0,0 @@
#include "gatewayController.h"
#include <algorithm>
#include <functional>
#include <random>
#include <QCryptographicHash>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkReply>
#include <QPromise>
#include <QUrl>
#include "QBlockCipher.h"
#include "QRsa.h"
#include "amneziaApplication.h"
#include "core/utils/api/apiUtils.h"
#include "core/utils/constants/apiKeys.h"
#include "core/utils/networkUtilities.h"
#include "core/utils/utilities.h"
#ifdef AMNEZIA_DESKTOP
#include "core/utils/ipcClient.h"
#endif
namespace
{
constexpr QLatin1String errorResponsePattern1("No active configuration found for");
constexpr QLatin1String errorResponsePattern2("No non-revoked public key found for");
constexpr QLatin1String errorResponsePattern3("Account not found.");
constexpr QLatin1String errorResponsePatternQrSessionNotFound("QR session not found");
constexpr QLatin1String errorResponsePatternSessionNotFound("Session not found");
constexpr QLatin1String updateRequestResponsePattern("client version update is required");
constexpr int httpStatusCodeNotFound = 404;
constexpr int httpStatusCodeConflict = 409;
constexpr int httpStatusCodeNotImplemented = 501;
constexpr int httpStatusCodePaymentRequired = 402;
constexpr int httpStatusCodeRequestTimeout = 408;
constexpr int httpStatusCodeUnprocessableEntity = 422;
constexpr QLatin1String unprocessableSubscriptionMessage("Failed to retrieve subscription information. Is it activated?");
constexpr int proxyStorageRequestTimeoutMsecs = 3000;
}
GatewayController::GatewayController(const QString &gatewayEndpoint, const bool isDevEnvironment, const int requestTimeoutMsecs,
const bool isStrictKillSwitchEnabled, QObject *parent)
: QObject(parent),
m_gatewayEndpoint(gatewayEndpoint),
m_isDevEnvironment(isDevEnvironment),
m_requestTimeoutMsecs(requestTimeoutMsecs),
m_isStrictKillSwitchEnabled(isStrictKillSwitchEnabled)
{
}
GatewayController::EncryptedRequestData GatewayController::prepareRequest(const QString &endpoint, const QJsonObject &apiPayload)
{
EncryptedRequestData encRequestData;
encRequestData.errorCode = ErrorCode::NoError;
#ifdef Q_OS_IOS
IosController::Instance()->requestInetAccess();
QThread::msleep(10);
#endif
encRequestData.request.setTransferTimeout(m_requestTimeoutMsecs);
encRequestData.request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
encRequestData.request.setRawHeader(QString("X-Client-Request-ID").toUtf8(), QUuid::createUuid().toString(QUuid::WithoutBraces).toUtf8());
encRequestData.request.setUrl(endpoint.arg(m_proxyUrl.isEmpty() ? m_gatewayEndpoint : m_proxyUrl));
// bypass killSwitch exceptions for API-gateway
#ifdef AMNEZIA_DESKTOP
if (m_isStrictKillSwitchEnabled) {
QString host = QUrl(encRequestData.request.url()).host();
QString ip = NetworkUtilities::getIPAddress(host);
if (!ip.isEmpty()) {
IpcClient::withInterface([&](QSharedPointer<IpcInterfaceReplica> iface) {
QRemoteObjectPendingReply<bool> reply = iface->addKillSwitchAllowedRange(QStringList { ip });
if (!reply.waitForFinished(1000) || !reply.returnValue())
qWarning() << "GatewayController::prepareRequest(): Failed to execute remote addKillSwitchAllowedRange call";
});
}
}
#endif
QSimpleCrypto::QBlockCipher blockCipher;
encRequestData.key = blockCipher.generatePrivateSalt(32);
encRequestData.iv = blockCipher.generatePrivateSalt(32);
encRequestData.salt = blockCipher.generatePrivateSalt(8);
QJsonObject keyPayload;
keyPayload[apiDefs::key::aesKey] = QString(encRequestData.key.toBase64());
keyPayload[apiDefs::key::aesIv] = QString(encRequestData.iv.toBase64());
keyPayload[apiDefs::key::aesSalt] = QString(encRequestData.salt.toBase64());
QByteArray encryptedKeyPayload;
QByteArray encryptedApiPayload;
try {
QSimpleCrypto::QRsa rsa;
EVP_PKEY *publicKey = nullptr;
try {
QByteArray rsaKey = m_isDevEnvironment ? DEV_AGW_PUBLIC_KEY : PROD_AGW_PUBLIC_KEY;
QSimpleCrypto::QRsa rsa;
publicKey = rsa.getPublicKeyFromByteArray(rsaKey);
} catch (...) {
Utils::logException();
qCritical() << "error loading public key from environment variables";
encRequestData.errorCode = ErrorCode::ApiMissingAgwPublicKey;
return encRequestData;
}
encryptedKeyPayload = rsa.encrypt(QJsonDocument(keyPayload).toJson(), publicKey, RSA_PKCS1_PADDING);
EVP_PKEY_free(publicKey);
encryptedApiPayload = blockCipher.encryptAesBlockCipher(QJsonDocument(apiPayload).toJson(), encRequestData.key, encRequestData.iv,
"", encRequestData.salt);
} catch (...) {
Utils::logException();
qCritical() << "error when encrypting the request body";
encRequestData.errorCode = ErrorCode::ApiConfigDecryptionError;
return encRequestData;
}
QJsonObject requestBody;
requestBody[apiDefs::key::keyPayload] = QString(encryptedKeyPayload.toBase64());
requestBody[apiDefs::key::apiPayload] = QString(encryptedApiPayload.toBase64());
encRequestData.requestBody = QJsonDocument(requestBody).toJson();
return encRequestData;
}
GatewayController::DecryptionResult GatewayController::tryDecryptResponseBody(const QByteArray &encryptedResponseBody,
QNetworkReply::NetworkError replyError, const QByteArray &key,
const QByteArray &iv, const QByteArray &salt)
{
DecryptionResult result;
result.decryptedBody = encryptedResponseBody;
result.isDecryptionSuccessful = false;
try {
QSimpleCrypto::QBlockCipher blockCipher;
result.decryptedBody = blockCipher.decryptAesBlockCipher(encryptedResponseBody, key, iv, "", salt);
result.isDecryptionSuccessful = true;
} catch (...) {
result.decryptedBody = encryptedResponseBody;
result.isDecryptionSuccessful = false;
}
return result;
}
ErrorCode GatewayController::post(const QString &endpoint, const QJsonObject apiPayload, QByteArray &responseBody)
{
EncryptedRequestData encRequestData = prepareRequest(endpoint, apiPayload);
if (encRequestData.errorCode != ErrorCode::NoError) {
return encRequestData.errorCode;
}
QNetworkReply *reply = amnApp->networkManager()->post(encRequestData.request, encRequestData.requestBody);
QEventLoop wait;
connect(reply, &QNetworkReply::finished, &wait, &QEventLoop::quit);
QList<QSslError> sslErrors;
connect(reply, &QNetworkReply::sslErrors, [this, &sslErrors](const QList<QSslError> &errors) { sslErrors = errors; });
wait.exec(QEventLoop::ExcludeUserInputEvents);
QByteArray encryptedResponseBody = reply->readAll();
QString replyErrorString = reply->errorString();
auto replyError = reply->error();
int httpStatusCode = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
reply->deleteLater();
auto decryptionResult =
tryDecryptResponseBody(encryptedResponseBody, replyError, encRequestData.key, encRequestData.iv, encRequestData.salt);
if (sslErrors.isEmpty() && shouldBypassProxy(replyError, decryptionResult.decryptedBody, decryptionResult.isDecryptionSuccessful)) {
auto requestFunction = [&encRequestData, &encryptedResponseBody](const QString &url) {
encRequestData.request.setUrl(url);
return amnApp->networkManager()->post(encRequestData.request, encRequestData.requestBody);
};
auto replyProcessingFunction = [&encryptedResponseBody, &replyErrorString, &replyError, &httpStatusCode, &sslErrors, &encRequestData,
&decryptionResult, this](QNetworkReply *reply, const QList<QSslError> &nestedSslErrors) {
encryptedResponseBody = reply->readAll();
replyErrorString = reply->errorString();
replyError = reply->error();
httpStatusCode = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
decryptionResult =
tryDecryptResponseBody(encryptedResponseBody, replyError, encRequestData.key, encRequestData.iv, encRequestData.salt);
if (!sslErrors.isEmpty()
|| shouldBypassProxy(replyError, decryptionResult.decryptedBody, decryptionResult.isDecryptionSuccessful)) {
sslErrors = nestedSslErrors;
return false;
}
return true;
};
auto serviceType = apiPayload.value(apiDefs::key::serviceType).toString("");
auto userCountryCode = apiPayload.value(apiDefs::key::userCountryCode).toString("");
bypassProxy(endpoint, serviceType, userCountryCode, requestFunction, replyProcessingFunction);
}
responseBody = decryptionResult.decryptedBody;
const auto errorCode =
apiUtils::checkNetworkReplyErrors(sslErrors, replyErrorString, replyError, httpStatusCode, responseBody);
if (errorCode) {
return errorCode;
}
if (!decryptionResult.isDecryptionSuccessful) {
qCritical() << "error when decrypting the request body";
return ErrorCode::ApiConfigDecryptionError;
}
return ErrorCode::NoError;
}
QFuture<QPair<ErrorCode, QByteArray>> GatewayController::postAsync(const QString &endpoint, const QJsonObject apiPayload)
{
auto promise = QSharedPointer<QPromise<QPair<ErrorCode, QByteArray>>>::create();
promise->start();
EncryptedRequestData encRequestData = prepareRequest(endpoint, apiPayload);
if (encRequestData.errorCode != ErrorCode::NoError) {
promise->addResult(qMakePair(encRequestData.errorCode, QByteArray()));
promise->finish();
return promise->future();
}
QNetworkReply *reply = amnApp->networkManager()->post(encRequestData.request, encRequestData.requestBody);
auto sslErrors = QSharedPointer<QList<QSslError>>::create();
connect(reply, &QNetworkReply::sslErrors, [sslErrors](const QList<QSslError> &errors) { *sslErrors = errors; });
connect(reply, &QNetworkReply::finished, this, [promise, sslErrors, encRequestData, endpoint, apiPayload, reply, this]() mutable {
QByteArray encryptedResponseBody = reply->readAll();
QString replyErrorString = reply->errorString();
auto replyError = reply->error();
int httpStatusCode = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
reply->deleteLater();
auto decryptionResult =
tryDecryptResponseBody(encryptedResponseBody, replyError, encRequestData.key, encRequestData.iv, encRequestData.salt);
auto processResponse = [promise, encRequestData](const GatewayController::DecryptionResult &decryptionResult,
const QList<QSslError> &sslErrors, QNetworkReply::NetworkError replyError,
const QString &replyErrorString, int httpStatusCode) {
auto errorCode = apiUtils::checkNetworkReplyErrors(sslErrors, replyErrorString, replyError, httpStatusCode,
decryptionResult.decryptedBody);
if (errorCode) {
promise->addResult(qMakePair(errorCode, decryptionResult.decryptedBody));
promise->finish();
return;
}
if (!decryptionResult.isDecryptionSuccessful) {
Utils::logException();
qCritical() << "error when decrypting the request body";
promise->addResult(qMakePair(ErrorCode::ApiConfigDecryptionError, QByteArray()));
promise->finish();
return;
}
promise->addResult(qMakePair(ErrorCode::NoError, decryptionResult.decryptedBody));
promise->finish();
};
if (sslErrors->isEmpty() && shouldBypassProxy(replyError, decryptionResult.decryptedBody, decryptionResult.isDecryptionSuccessful)) {
auto serviceType = apiPayload.value(apiDefs::key::serviceType).toString("");
auto userCountryCode = apiPayload.value(apiDefs::key::userCountryCode).toString("");
QStringList primaryBaseUrls;
QStringList fallbackBaseUrls;
if (m_isDevEnvironment) {
primaryBaseUrls = QString(DEV_S3_ENDPOINT).split(", ", Qt::SkipEmptyParts);
} else {
primaryBaseUrls = QString(PROD_S3_ENDPOINT).split(", ", Qt::SkipEmptyParts);
fallbackBaseUrls = QString(FALLBACK_S3_ENDPOINT).split(", ", Qt::SkipEmptyParts);
}
std::random_device randomDevice;
std::mt19937 generator(randomDevice());
std::shuffle(primaryBaseUrls.begin(), primaryBaseUrls.end(), generator);
std::shuffle(fallbackBaseUrls.begin(), fallbackBaseUrls.end(), generator);
auto appendStorageUrls = [&serviceType, &userCountryCode](const QStringList &baseUrls, QStringList &target) {
if (!serviceType.isEmpty()) {
for (const auto &baseUrl : baseUrls) {
QByteArray path = ("endpoints-" + serviceType + "-" + userCountryCode).toUtf8();
target.push_back(baseUrl + path.toBase64(QByteArray::Base64UrlEncoding | QByteArray::OmitTrailingEquals) + ".json");
}
}
for (const auto &baseUrl : baseUrls) {
target.push_back(baseUrl + "endpoints.json");
}
};
QStringList proxyStorageUrls;
appendStorageUrls(primaryBaseUrls, proxyStorageUrls);
appendStorageUrls(fallbackBaseUrls, proxyStorageUrls);
getProxyUrlsAsync(proxyStorageUrls, 0, [this, encRequestData, endpoint, processResponse](const QStringList &proxyUrls) {
getProxyUrlAsync(proxyUrls, 0, [this, encRequestData, endpoint, processResponse](const QString &proxyUrl) {
bypassProxyAsync(endpoint, proxyUrl, encRequestData,
[processResponse, this](const QByteArray &decryptedBody, bool isDecryptionSuccessful,
const QList<QSslError> &sslErrors, QNetworkReply::NetworkError replyError,
const QString &replyErrorString, int httpStatusCode) {
GatewayController::DecryptionResult result;
result.decryptedBody = decryptedBody;
result.isDecryptionSuccessful = isDecryptionSuccessful;
processResponse(result, sslErrors, replyError, replyErrorString, httpStatusCode);
});
});
});
} else {
processResponse(decryptionResult, *sslErrors, replyError, replyErrorString, httpStatusCode);
}
});
return promise->future();
}
QStringList GatewayController::getProxyUrls(const QString &serviceType, const QString &userCountryCode)
{
QNetworkRequest request;
request.setTransferTimeout(proxyStorageRequestTimeoutMsecs);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
QEventLoop wait;
QList<QSslError> sslErrors;
QNetworkReply *reply;
QStringList primaryBaseUrls;
QStringList fallbackBaseUrls;
if (m_isDevEnvironment) {
primaryBaseUrls = QString(DEV_S3_ENDPOINT).split(", ", Qt::SkipEmptyParts);
} else {
primaryBaseUrls = QString(PROD_S3_ENDPOINT).split(", ", Qt::SkipEmptyParts);
fallbackBaseUrls = QString(FALLBACK_S3_ENDPOINT).split(", ", Qt::SkipEmptyParts);
}
std::random_device randomDevice;
std::mt19937 generator(randomDevice());
std::shuffle(primaryBaseUrls.begin(), primaryBaseUrls.end(), generator);
std::shuffle(fallbackBaseUrls.begin(), fallbackBaseUrls.end(), generator);
QByteArray key = m_isDevEnvironment ? DEV_AGW_PUBLIC_KEY : PROD_AGW_PUBLIC_KEY;
auto appendStorageUrls = [&serviceType, &userCountryCode](const QStringList &baseUrls, QStringList &target) {
if (!serviceType.isEmpty()) {
for (const auto &baseUrl : baseUrls) {
QByteArray path = ("endpoints-" + serviceType + "-" + userCountryCode).toUtf8();
target.push_back(baseUrl + path.toBase64(QByteArray::Base64UrlEncoding | QByteArray::OmitTrailingEquals) + ".json");
}
}
for (const auto &baseUrl : baseUrls) {
target.push_back(baseUrl + "endpoints.json");
}
};
QStringList proxyStorageUrls;
appendStorageUrls(primaryBaseUrls, proxyStorageUrls);
appendStorageUrls(fallbackBaseUrls, proxyStorageUrls);
if (proxyStorageUrls.empty()) {
qDebug() << "empty storage endpoint list";
return {};
}
for (const auto &proxyStorageUrl : proxyStorageUrls) {
request.setUrl(proxyStorageUrl);
reply = amnApp->networkManager()->get(request);
connect(reply, &QNetworkReply::finished, &wait, &QEventLoop::quit);
connect(reply, &QNetworkReply::sslErrors, [this, &sslErrors](const QList<QSslError> &errors) { sslErrors = errors; });
wait.exec(QEventLoop::ExcludeUserInputEvents);
if (reply->error() == QNetworkReply::NetworkError::NoError) {
auto encryptedResponseBody = reply->readAll();
reply->deleteLater();
EVP_PKEY *privateKey = nullptr;
QByteArray responseBody;
try {
if (!m_isDevEnvironment) {
QCryptographicHash hash(QCryptographicHash::Sha512);
hash.addData(key);
QByteArray hashResult = hash.result().toHex();
QByteArray key = QByteArray::fromHex(hashResult.left(64));
QByteArray iv = QByteArray::fromHex(hashResult.mid(64, 32));
QByteArray ba = QByteArray::fromBase64(encryptedResponseBody);
QSimpleCrypto::QBlockCipher blockCipher;
responseBody = blockCipher.decryptAesBlockCipher(ba, key, iv);
} else {
responseBody = encryptedResponseBody;
}
} catch (...) {
Utils::logException();
qCritical() << "error loading private key from environment variables or decrypting payload" << encryptedResponseBody;
continue;
}
auto endpointsArray = QJsonDocument::fromJson(responseBody).array();
QStringList endpoints;
for (const auto &endpoint : endpointsArray) {
endpoints.push_back(endpoint.toString());
}
return endpoints;
} else {
auto replyError = reply->error();
int httpStatusCode = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
qDebug() << replyError;
qDebug() << httpStatusCode;
qDebug() << "go to the next storage endpoint";
reply->deleteLater();
}
}
return {};
}
bool GatewayController::shouldBypassProxy(const QNetworkReply::NetworkError &replyError, const QByteArray &decryptedResponseBody,
bool isDecryptionSuccessful)
{
const QByteArray &responseBody = decryptedResponseBody;
int apiHttpStatus = -1;
QString apiErrorMessage;
if (isDecryptionSuccessful) {
QJsonDocument jsonDoc = QJsonDocument::fromJson(responseBody);
if (jsonDoc.isObject()) {
QJsonObject jsonObj = jsonDoc.object();
apiHttpStatus = jsonObj.value("http_status").toInt(-1);
apiErrorMessage = jsonObj.value(QStringLiteral("message")).toString().trimmed();
}
} else {
qDebug() << "failed to decrypt the data";
return true;
}
if (replyError == QNetworkReply::NetworkError::OperationCanceledError || replyError == QNetworkReply::NetworkError::TimeoutError) {
qDebug() << "timeout occurred";
qDebug() << replyError;
return true;
}
if (responseBody.contains("html")) {
qDebug() << "the response contains an html tag";
return true;
}
if (apiHttpStatus == httpStatusCodeRequestTimeout) {
return false;
}
if (apiHttpStatus == httpStatusCodeNotFound) {
if (responseBody.contains(errorResponsePattern1) || responseBody.contains(errorResponsePattern2)
|| responseBody.contains(errorResponsePattern3) || responseBody.contains(errorResponsePatternQrSessionNotFound)
|| responseBody.contains(errorResponsePatternSessionNotFound)) {
return false;
} else {
qDebug() << replyError;
return true;
}
}
if (apiHttpStatus == httpStatusCodeNotImplemented) {
if (responseBody.contains(updateRequestResponsePattern)) {
return false;
} else {
qDebug() << replyError;
return true;
}
}
if (apiHttpStatus == httpStatusCodeConflict) {
return false;
}
if (apiHttpStatus == httpStatusCodePaymentRequired) {
return false;
}
if (apiHttpStatus == httpStatusCodeUnprocessableEntity) {
return apiErrorMessage != unprocessableSubscriptionMessage;
}
if (replyError != QNetworkReply::NetworkError::NoError) {
qDebug() << replyError;
return true;
}
return false;
}
void GatewayController::bypassProxy(const QString &endpoint, const QString &serviceType, const QString &userCountryCode,
std::function<QNetworkReply *(const QString &url)> requestFunction,
std::function<bool(QNetworkReply *reply, const QList<QSslError> &sslErrors)> replyProcessingFunction)
{
QStringList proxyUrls = getProxyUrls(serviceType, userCountryCode);
std::random_device randomDevice;
std::mt19937 generator(randomDevice());
std::shuffle(proxyUrls.begin(), proxyUrls.end(), generator);
QByteArray responseBody;
auto bypassFunction = [this](const QString &endpoint, const QString &proxyUrl,
std::function<QNetworkReply *(const QString &url)> requestFunction,
std::function<bool(QNetworkReply * reply, const QList<QSslError> &sslErrors)> replyProcessingFunction) {
QEventLoop wait;
QList<QSslError> sslErrors;
qDebug() << "go to the next proxy endpoint";
QNetworkReply *reply = requestFunction(endpoint.arg(proxyUrl));
QObject::connect(reply, &QNetworkReply::finished, &wait, &QEventLoop::quit);
connect(reply, &QNetworkReply::sslErrors, [this, &sslErrors](const QList<QSslError> &errors) { sslErrors = errors; });
wait.exec(QEventLoop::ExcludeUserInputEvents);
auto result = replyProcessingFunction(reply, sslErrors);
reply->deleteLater();
return result;
};
if (m_proxyUrl.isEmpty()) {
QNetworkRequest request;
request.setTransferTimeout(1000);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
QEventLoop wait;
QList<QSslError> sslErrors;
QNetworkReply *reply;
for (const QString &proxyUrl : proxyUrls) {
request.setUrl(proxyUrl + "lmbd-health");
reply = amnApp->networkManager()->get(request);
connect(reply, &QNetworkReply::finished, &wait, &QEventLoop::quit);
connect(reply, &QNetworkReply::sslErrors, [this, &sslErrors](const QList<QSslError> &errors) { sslErrors = errors; });
wait.exec(QEventLoop::ExcludeUserInputEvents);
if (reply->error() == QNetworkReply::NetworkError::NoError) {
reply->deleteLater();
m_proxyUrl = proxyUrl;
if (!m_proxyUrl.isEmpty()) {
break;
}
} else {
reply->deleteLater();
}
}
}
if (!m_proxyUrl.isEmpty()) {
if (bypassFunction(endpoint, m_proxyUrl, requestFunction, replyProcessingFunction)) {
return;
}
}
for (const QString &proxyUrl : proxyUrls) {
if (bypassFunction(endpoint, proxyUrl, requestFunction, replyProcessingFunction)) {
m_proxyUrl = proxyUrl;
break;
}
}
}
void GatewayController::getProxyUrlsAsync(const QStringList proxyStorageUrls, const int currentProxyStorageIndex,
std::function<void(const QStringList &)> onComplete)
{
if (currentProxyStorageIndex >= proxyStorageUrls.size()) {
onComplete({});
return;
}
QNetworkRequest request;
request.setTransferTimeout(proxyStorageRequestTimeoutMsecs);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setUrl(proxyStorageUrls[currentProxyStorageIndex]);
QNetworkReply *reply = amnApp->networkManager()->get(request);
// connect(reply, &QNetworkReply::sslErrors, this, [state](const QList<QSslError> &e) { *(state->sslErrors) = e; });
connect(reply, &QNetworkReply::finished, this, [this, proxyStorageUrls, currentProxyStorageIndex, onComplete, reply]() {
if (reply->error() == QNetworkReply::NoError) {
QByteArray encrypted = reply->readAll();
reply->deleteLater();
QByteArray responseBody;
try {
QByteArray key = m_isDevEnvironment ? DEV_AGW_PUBLIC_KEY : PROD_AGW_PUBLIC_KEY;
if (!m_isDevEnvironment) {
QCryptographicHash hash(QCryptographicHash::Sha512);
hash.addData(key);
QByteArray h = hash.result().toHex();
QByteArray decKey = QByteArray::fromHex(h.left(64));
QByteArray iv = QByteArray::fromHex(h.mid(64, 32));
QByteArray ba = QByteArray::fromBase64(encrypted);
QSimpleCrypto::QBlockCipher cipher;
responseBody = cipher.decryptAesBlockCipher(ba, decKey, iv);
} else {
responseBody = encrypted;
}
} catch (...) {
Utils::logException();
qCritical() << "error decrypting payload";
QMetaObject::invokeMethod(
this, [=]() { getProxyUrlsAsync(proxyStorageUrls, currentProxyStorageIndex + 1, onComplete); }, Qt::QueuedConnection);
return;
}
QJsonArray endpointsArray = QJsonDocument::fromJson(responseBody).array();
QStringList endpoints;
for (const QJsonValue &endpoint : endpointsArray)
endpoints.push_back(endpoint.toString());
QStringList shuffled = endpoints;
std::random_device randomDevice;
std::mt19937 generator(randomDevice());
std::shuffle(shuffled.begin(), shuffled.end(), generator);
onComplete(shuffled);
return;
}
int httpStatusCode = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
qDebug() << httpStatusCode;
qDebug() << "go to the next storage endpoint";
reply->deleteLater();
QMetaObject::invokeMethod(
this, [=]() { getProxyUrlsAsync(proxyStorageUrls, currentProxyStorageIndex + 1, onComplete); }, Qt::QueuedConnection);
});
}
void GatewayController::getProxyUrlAsync(const QStringList proxyUrls, const int currentProxyIndex,
std::function<void(const QString &)> onComplete)
{
if (currentProxyIndex >= proxyUrls.size()) {
onComplete("");
return;
}
QNetworkRequest request;
request.setTransferTimeout(1000);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setUrl(proxyUrls[currentProxyIndex] + "lmbd-health");
QNetworkReply *reply = amnApp->networkManager()->get(request);
// connect(reply, &QNetworkReply::sslErrors, this, [state](const QList<QSslError> &e) {
// *(state->sslErrors) = e;
// });
connect(reply, &QNetworkReply::finished, this, [this, proxyUrls, currentProxyIndex, onComplete, reply]() {
reply->deleteLater();
if (reply->error() == QNetworkReply::NoError) {
m_proxyUrl = proxyUrls[currentProxyIndex];
onComplete(m_proxyUrl);
return;
}
qDebug() << "go to the next proxy endpoint";
QMetaObject::invokeMethod(this, [=]() { getProxyUrlAsync(proxyUrls, currentProxyIndex + 1, onComplete); }, Qt::QueuedConnection);
});
}
void GatewayController::bypassProxyAsync(
const QString &endpoint, const QString &proxyUrl, EncryptedRequestData encRequestData,
std::function<void(const QByteArray &, bool, const QList<QSslError> &, QNetworkReply::NetworkError, const QString &, int)> onComplete)
{
auto sslErrors = QSharedPointer<QList<QSslError>>::create();
if (proxyUrl.isEmpty()) {
onComplete(QByteArray(), false, *sslErrors, QNetworkReply::InternalServerError, "empty proxy url", 0);
return;
}
QNetworkRequest request = encRequestData.request;
request.setUrl(endpoint.arg(proxyUrl));
QNetworkReply *reply = amnApp->networkManager()->post(request, encRequestData.requestBody);
connect(reply, &QNetworkReply::sslErrors, this, [sslErrors](const QList<QSslError> &errors) { *sslErrors = errors; });
connect(reply, &QNetworkReply::finished, this, [sslErrors, onComplete, encRequestData, reply, this]() {
QByteArray encryptedResponseBody = reply->readAll();
QString replyErrorString = reply->errorString();
auto replyError = reply->error();
int httpStatusCode = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute).toInt();
reply->deleteLater();
auto decryptionResult =
tryDecryptResponseBody(encryptedResponseBody, replyError, encRequestData.key, encRequestData.iv, encRequestData.salt);
onComplete(decryptionResult.decryptedBody, decryptionResult.isDecryptionSuccessful, *sslErrors, replyError, replyErrorString,
httpStatusCode);
});
}

View File

@@ -1,72 +0,0 @@
#ifndef GATEWAYCONTROLLER_H
#define GATEWAYCONTROLLER_H
#include <QFuture>
#include <QNetworkReply>
#include <QObject>
#include <QPair>
#include <QPromise>
#include <QSharedPointer>
#include "core/utils/errorCodes.h"
#include "core/utils/routeModes.h"
#include "core/utils/commonStructs.h"
#ifdef Q_OS_IOS
#include "platforms/ios/ios_controller.h"
#endif
class GatewayController : public QObject
{
Q_OBJECT
public:
explicit GatewayController(const QString &gatewayEndpoint, const bool isDevEnvironment, const int requestTimeoutMsecs,
const bool isStrictKillSwitchEnabled, QObject *parent = nullptr);
amnezia::ErrorCode post(const QString &endpoint, const QJsonObject apiPayload, QByteArray &responseBody);
QFuture<QPair<amnezia::ErrorCode, QByteArray>> postAsync(const QString &endpoint, const QJsonObject apiPayload);
private:
struct EncryptedRequestData
{
QNetworkRequest request;
QByteArray requestBody;
QByteArray key;
QByteArray iv;
QByteArray salt;
amnezia::ErrorCode errorCode;
};
struct DecryptionResult
{
QByteArray decryptedBody;
bool isDecryptionSuccessful;
};
EncryptedRequestData prepareRequest(const QString &endpoint, const QJsonObject &apiPayload);
DecryptionResult tryDecryptResponseBody(const QByteArray &encryptedResponseBody, QNetworkReply::NetworkError replyError,
const QByteArray &key, const QByteArray &iv, const QByteArray &salt);
QStringList getProxyUrls(const QString &serviceType, const QString &userCountryCode);
bool shouldBypassProxy(const QNetworkReply::NetworkError &replyError, const QByteArray &decryptedResponseBody, bool isDecryptionSuccessful);
void bypassProxy(const QString &endpoint, const QString &serviceType, const QString &userCountryCode,
std::function<QNetworkReply *(const QString &url)> requestFunction,
std::function<bool(QNetworkReply *reply, const QList<QSslError> &sslErrors)> replyProcessingFunction);
void getProxyUrlsAsync(const QStringList proxyStorageUrls, const int currentProxyStorageIndex,
std::function<void(const QStringList &)> onComplete);
void getProxyUrlAsync(const QStringList proxyUrls, const int currentProxyIndex, std::function<void(const QString &)> onComplete);
void bypassProxyAsync(
const QString &endpoint, const QString &proxyUrl, EncryptedRequestData encRequestData,
std::function<void(const QByteArray &, bool, const QList<QSslError> &, QNetworkReply::NetworkError, const QString &, int)> onComplete);
int m_requestTimeoutMsecs;
QString m_gatewayEndpoint;
bool m_isDevEnvironment = false;
bool m_isStrictKillSwitchEnabled = false;
inline static QString m_proxyUrl;
};
#endif // GATEWAYCONTROLLER_H

View File

@@ -0,0 +1,208 @@
#include "gatewayControllerAdapter.h"
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include <QDebug>
#include <QEventLoop>
#include <QJsonDocument>
#include <QPointer>
#include <QPromise>
#include <QSharedPointer>
#include <QStringList>
#include <QThread>
#include <agw/gateway_controller.h>
#include <agw/config.h>
#include <agw/types.h>
#include "core/utils/constants/apiKeys.h"
#include "core/utils/networkUtilities.h"
#include "embedded_agw_public_keys.h"
#ifdef Q_OS_IOS
#include "platforms/ios/ios_controller.h"
#endif
#ifdef AMNEZIA_DESKTOP
#include "core/utils/ipcClient.h"
#endif
namespace
{
amnezia::ErrorCode mapError(agw::ErrorCode error)
{
if (error == agw::ErrorCode::Cancelled) {
return amnezia::ErrorCode::ApiConfigTimeoutError;
}
return static_cast<amnezia::ErrorCode>(static_cast<int>(error));
}
std::vector<std::string> splitCsv(const QString &value)
{
std::vector<std::string> out;
const QStringList parts = value.split(", ", Qt::SkipEmptyParts);
for (const QString &p : parts) {
out.push_back(p.toStdString());
}
return out;
}
agw::Config makeConfig(const QString &gatewayEndpoint, bool isDevEnvironment, int requestTimeoutMsecs,
bool isStrictKillSwitchEnabled)
{
agw::Config cfg;
cfg.gatewayEndpoint = gatewayEndpoint.toStdString();
const QByteArray pem = isDevEnvironment ? DEV_AGW_PUBLIC_KEY : PROD_AGW_PUBLIC_KEY;
cfg.agwPublicKeyPem = std::string(pem.constData(), static_cast<std::size_t>(pem.size()));
if (isDevEnvironment) {
cfg.s3PrimaryEndpoints = splitCsv(QString(DEV_S3_ENDPOINT));
} else {
cfg.s3PrimaryEndpoints = splitCsv(QString(PROD_S3_ENDPOINT));
cfg.s3FallbackEndpoints = splitCsv(QString(FALLBACK_S3_ENDPOINT));
}
cfg.isDevEnvironment = isDevEnvironment;
cfg.requestTimeoutMsecs = requestTimeoutMsecs;
cfg.log = [](agw::LogLevel level, const std::string &message) {
const QString msg = QString::fromStdString(message);
switch (level) {
case agw::LogLevel::Error: qWarning() << "[agw]" << msg; break;
case agw::LogLevel::Warning: qWarning() << "[agw]" << msg; break;
default: qDebug() << "[agw]" << msg; break;
}
};
cfg.onBeforeRequest = [isStrictKillSwitchEnabled](const std::string &hostStd) {
const QString host = QString::fromStdString(hostStd);
(void)host;
(void)isStrictKillSwitchEnabled;
#ifdef Q_OS_IOS
IosController::Instance()->requestInetAccess();
QThread::msleep(10);
#endif
#ifdef AMNEZIA_DESKTOP
if (isStrictKillSwitchEnabled) {
const QString ip = NetworkUtilities::getIPAddress(host);
if (!ip.isEmpty()) {
IpcClient::withInterface([&](QSharedPointer<IpcInterfaceReplica> iface) {
QRemoteObjectPendingReply<bool> reply = iface->addKillSwitchAllowedRange(QStringList { ip });
if (!reply.waitForFinished(1000) || !reply.returnValue()) {
qWarning() << "GatewayControllerAdapter: addKillSwitchAllowedRange failed";
}
});
}
}
#endif
};
return cfg;
}
std::shared_ptr<agw::GatewayController> getClientForEnv(const QString &gatewayEndpoint, bool isDevEnvironment,
int requestTimeoutMsecs, bool isStrictKillSwitchEnabled)
{
static std::mutex mutex;
static std::map<std::string, std::shared_ptr<agw::GatewayController>> clients;
const std::string key = gatewayEndpoint.toStdString() + "|" + (isDevEnvironment ? "1" : "0") + "|"
+ std::to_string(requestTimeoutMsecs) + "|" + (isStrictKillSwitchEnabled ? "1" : "0");
std::lock_guard<std::mutex> lock(mutex);
auto it = clients.find(key);
if (it != clients.end()) {
return it->second;
}
auto client = std::make_shared<agw::GatewayController>(
makeConfig(gatewayEndpoint, isDevEnvironment, requestTimeoutMsecs, isStrictKillSwitchEnabled));
clients.emplace(key, client);
return client;
}
}
GatewayControllerAdapter::GatewayControllerAdapter(const QString &gatewayEndpoint, const bool isDevEnvironment, const int requestTimeoutMsecs,
const bool isStrictKillSwitchEnabled, QObject *parent)
: QObject(parent),
m_controller(getClientForEnv(gatewayEndpoint, isDevEnvironment, requestTimeoutMsecs, isStrictKillSwitchEnabled))
{
}
amnezia::ErrorCode GatewayControllerAdapter::post(const QString &endpoint, const QJsonObject apiPayload, QByteArray &responseBody)
{
const std::string payload = QJsonDocument(apiPayload).toJson().toStdString();
const std::string serviceType = apiPayload.value(apiDefs::key::serviceType).toString().toStdString();
const std::string userCountryCode = apiPayload.value(apiDefs::key::userCountryCode).toString().toStdString();
qInfo().noquote() << "[agw-adapter] post (sync) endpoint=" << endpoint
<< "payloadLen=" << payload.size() << "thread=" << QThread::currentThread();
QEventLoop loop;
QObject context;
agw::Response result;
m_controller->postAsync(
endpoint.toStdString(), payload,
[&loop, &context, &result](agw::Response r) {
QMetaObject::invokeMethod(
&context,
[&loop, &result, r]() {
result = r;
loop.quit();
},
Qt::QueuedConnection);
},
agw::FailoverContext { serviceType, userCountryCode });
loop.exec(QEventLoop::ExcludeUserInputEvents);
responseBody = QByteArray::fromStdString(result.body);
const amnezia::ErrorCode ec = mapError(result.error);
qInfo().noquote() << "[agw-adapter] post (sync) result errorCode=" << static_cast<int>(ec)
<< "bodyLen=" << responseBody.size();
return ec;
}
QFuture<QPair<amnezia::ErrorCode, QByteArray>> GatewayControllerAdapter::postAsync(const QString &endpoint, const QJsonObject apiPayload)
{
auto promise = QSharedPointer<QPromise<QPair<amnezia::ErrorCode, QByteArray>>>::create();
promise->start();
QFuture<QPair<amnezia::ErrorCode, QByteArray>> future = promise->future();
const std::string payload = QJsonDocument(apiPayload).toJson().toStdString();
const std::string serviceType = apiPayload.value(apiDefs::key::serviceType).toString().toStdString();
const std::string userCountryCode = apiPayload.value(apiDefs::key::userCountryCode).toString().toStdString();
QPointer<GatewayControllerAdapter> self(this);
qInfo().noquote() << "[agw-adapter] postAsync endpoint=" << endpoint
<< "payloadLen=" << payload.size() << "callerThread=" << QThread::currentThread();
m_controller->postAsync(
endpoint.toStdString(), payload,
[promise, self](agw::Response r) {
const amnezia::ErrorCode ec = mapError(r.error);
const QByteArray body = QByteArray::fromStdString(r.body);
qInfo().noquote() << "[agw-adapter] postAsync SDK callback errorCode=" << static_cast<int>(ec)
<< "bodyLen=" << body.size() << "poolThread=" << QThread::currentThread()
<< "→ marshalling to object thread";
auto deliver = [promise, ec, body]() {
promise->addResult(qMakePair(ec, body));
promise->finish();
};
if (self) {
QMetaObject::invokeMethod(self.data(), deliver, Qt::QueuedConnection);
} else {
deliver();
}
},
agw::FailoverContext { serviceType, userCountryCode });
return future;
}

View File

@@ -0,0 +1,36 @@
#ifndef GATEWAYCONTROLLERADAPTER_H
#define GATEWAYCONTROLLERADAPTER_H
#include <memory>
#include <QByteArray>
#include <QFuture>
#include <QJsonObject>
#include <QObject>
#include <QPair>
#include <QString>
#include "core/utils/errorCodes.h"
namespace agw
{
class GatewayController;
}
class GatewayControllerAdapter : public QObject
{
Q_OBJECT
public:
explicit GatewayControllerAdapter(const QString &gatewayEndpoint, const bool isDevEnvironment, const int requestTimeoutMsecs,
const bool isStrictKillSwitchEnabled, QObject *parent = nullptr);
amnezia::ErrorCode post(const QString &endpoint, const QJsonObject apiPayload, QByteArray &responseBody);
QFuture<QPair<amnezia::ErrorCode, QByteArray>> postAsync(const QString &endpoint, const QJsonObject apiPayload);
private:
std::shared_ptr<agw::GatewayController> m_controller;
};
#endif

View File

@@ -11,7 +11,7 @@
#include "amneziaApplication.h"
#include "logger.h"
#include "version.h"
#include "core/controllers/gatewayController.h"
#include "core/controllers/gatewayControllerAdapter.h"
#include "core/utils/constants/apiKeys.h"
#include "core/utils/selfhosted/scriptsRegistry.h"
@@ -92,7 +92,7 @@ void UpdateController::doGetAsync(const QString &endpoint, std::function<void(bo
void UpdateController::fetchGatewayUrl()
{
auto gatewayController = QSharedPointer<GatewayController>::create(m_appSettingsRepository->getGatewayEndpoint(),
auto gatewayController = QSharedPointer<GatewayControllerAdapter>::create(m_appSettingsRepository->getGatewayEndpoint(),
m_appSettingsRepository->isDevGatewayEnv(),
7000,
m_appSettingsRepository->isStrictKillSwitchEnabled());

View File

@@ -5,10 +5,14 @@ class AmneziaVPN(ConanFile):
generators = "VirtualBuildEnv", "CMakeConfigDeps"
options = {
"macos_ne": [True, False]
"macos_ne": [True, False],
# True (по умолчанию): SDK собирается из исходников в дереве клиента (отлаживается) —
# тянем его зависимости. False: потребляем готовый Conan-пакет agw-sdk.
"agw_sdk_from_source": [True, False]
}
default_options = {
"macos_ne": False
"macos_ne": False,
"agw_sdk_from_source": True
}
def requirements(self):
@@ -45,3 +49,11 @@ class AmneziaVPN(ConanFile):
self.requires("libssh/0.11.3@amnezia")
self.requires("openssl/3.6.2")
self.requires("zlib/1.3.2")
# AGW SDK — транспорт к API-шлюзу (Qt-free, общий OpenSSL/3.6.2).
if self.options.agw_sdk_from_source:
# Собираем SDK из исходников (agw-sdk/ через add_subdirectory) — нужны его зависимости.
self.requires("libcurl/8.10.1")
self.requires("nlohmann_json/3.11.3")
else:
self.requires("agw-sdk/0.1.0")