diff --git a/config/prism.php b/config/prism.php index 046b6781f..03dc2d91d 100644 --- a/config/prism.php +++ b/config/prism.php @@ -60,5 +60,11 @@ 'x_title' => env('OPENROUTER_SITE_X_TITLE', null), ], ], + 'azure' => [ + 'url' => env('AZURE_AI_URL', ''), + 'api_key' => env('AZURE_AI_API_KEY', ''), + 'api_version' => env('AZURE_AI_API_VERSION', '2024-10-21'), + 'deployment_name' => env('AZURE_AI_DEPLOYMENT', ''), + ], ], ]; diff --git a/src/Enums/Provider.php b/src/Enums/Provider.php index 21c7a17f6..8985de35c 100644 --- a/src/Enums/Provider.php +++ b/src/Enums/Provider.php @@ -17,4 +17,5 @@ enum Provider: string case Gemini = 'gemini'; case VoyageAI = 'voyageai'; case ElevenLabs = 'elevenlabs'; + case Azure = 'azure'; } diff --git a/src/PrismManager.php b/src/PrismManager.php index 035a21f99..66735e8e8 100644 --- a/src/PrismManager.php +++ b/src/PrismManager.php @@ -9,6 +9,7 @@ use InvalidArgumentException; use Prism\Prism\Enums\Provider as ProviderEnum; use Prism\Prism\Providers\Anthropic\Anthropic; +use Prism\Prism\Providers\Azure\Azure; use Prism\Prism\Providers\DeepSeek\DeepSeek; use Prism\Prism\Providers\ElevenLabs\ElevenLabs; use Prism\Prism\Providers\Gemini\Gemini; @@ -215,6 +216,19 @@ protected function createOpenrouterProvider(array $config): OpenRouter ); } + /** + * @param array $config + */ + protected function createAzureProvider(array $config): Azure + { + return new Azure( + url: $config['url'] ?? '', + apiKey: $config['api_key'] ?? '', + apiVersion: $config['api_version'] ?? '2024-10-21', + deploymentName: $config['deployment_name'] ?? null, + ); + } + /** * @param array $config */ diff --git a/src/Providers/Azure/Azure.php b/src/Providers/Azure/Azure.php new file mode 100644 index 000000000..3180effbd --- /dev/null +++ b/src/Providers/Azure/Azure.php @@ -0,0 +1,145 @@ +client( + $request->clientOptions(), + $request->clientRetry(), + $this->buildUrl($request->model()) + )); + + return $handler->handle($request); + } + + #[\Override] + public function structured(StructuredRequest $request): StructuredResponse + { + $handler = new Structured($this->client( + $request->clientOptions(), + $request->clientRetry(), + $this->buildUrl($request->model()) + )); + + return $handler->handle($request); + } + + #[\Override] + public function embeddings(EmbeddingsRequest $request): EmbeddingsResponse + { + $handler = new Embeddings($this->client( + $request->clientOptions(), + $request->clientRetry(), + $this->buildUrl($request->model()) + )); + + return $handler->handle($request); + } + + #[\Override] + public function stream(TextRequest $request): Generator + { + $handler = new Stream($this->client( + $request->clientOptions(), + $request->clientRetry(), + $this->buildUrl($request->model()) + )); + + return $handler->handle($request); + } + + public function handleRequestException(string $model, RequestException $e): never + { + $statusCode = $e->response->getStatusCode(); + $responseData = $e->response->json(); + $errorMessage = data_get($responseData, 'error.message', 'Unknown error'); + + match ($statusCode) { + 429 => throw PrismRateLimitedException::make( + rateLimits: [], + retryAfter: (int) $e->response->header('retry-after') + ), + 413 => throw PrismRequestTooLargeException::make(ProviderName::Azure), + 400, 404 => throw PrismException::providerResponseError( + sprintf('Azure Error: %s', $errorMessage) + ), + default => throw PrismException::providerRequestError($model, $e), + }; + } + + /** + * Build the URL for the Azure deployment. + * Supports both Azure OpenAI and Azure AI Model Inference endpoints. + */ + protected function buildUrl(string $model): string + { + // If URL already contains the full path, use it directly + if (str_contains($this->url, '/chat/completions') || str_contains($this->url, '/embeddings')) { + return $this->url; + } + + // Use deployment name from config, or model name as fallback + $deployment = $this->deploymentName ?: $model; + + // If URL already contains 'deployments', append the deployment + if (str_contains($this->url, '/openai/deployments')) { + return rtrim($this->url, '/')."/{$deployment}"; + } + + // Build standard Azure OpenAI URL pattern + return rtrim($this->url, '/')."/openai/deployments/{$deployment}"; + } + + /** + * @param array $options + * @param array $retry + */ + protected function client(array $options = [], array $retry = [], ?string $baseUrl = null): PendingRequest + { + return $this->baseClient() + ->withHeaders([ + 'api-key' => $this->apiKey, + ]) + ->withQueryParameters([ + 'api-version' => $this->apiVersion, + ]) + ->withOptions($options) + ->when($retry !== [], fn ($client) => $client->retry(...$retry)) + ->baseUrl($baseUrl ?? $this->url); + } +} diff --git a/src/Providers/Azure/Concerns/MapsFinishReason.php b/src/Providers/Azure/Concerns/MapsFinishReason.php new file mode 100644 index 000000000..1a774baaa --- /dev/null +++ b/src/Providers/Azure/Concerns/MapsFinishReason.php @@ -0,0 +1,19 @@ + $data + */ + protected function mapFinishReason(array $data): FinishReason + { + return FinishReasonMap::map(data_get($data, 'choices.0.finish_reason', '')); + } +} diff --git a/src/Providers/Azure/Concerns/ValidatesResponses.php b/src/Providers/Azure/Concerns/ValidatesResponses.php new file mode 100644 index 000000000..56e95ee47 --- /dev/null +++ b/src/Providers/Azure/Concerns/ValidatesResponses.php @@ -0,0 +1,26 @@ + $data + */ + protected function validateResponse(array $data): void + { + if ($data === []) { + throw PrismException::providerResponseError('Azure Error: Empty response'); + } + + if (data_get($data, 'error')) { + throw PrismException::providerResponseError( + sprintf('Azure Error: %s', data_get($data, 'error.message', 'Unknown error')) + ); + } + } +} diff --git a/src/Providers/Azure/Handlers/Embeddings.php b/src/Providers/Azure/Handlers/Embeddings.php new file mode 100644 index 000000000..c53b31e51 --- /dev/null +++ b/src/Providers/Azure/Handlers/Embeddings.php @@ -0,0 +1,64 @@ +sendRequest($request); + + $this->validateResponse($response); + + $data = $response->json(); + + return new EmbeddingsResponse( + embeddings: array_map(fn (array $item): Embedding => Embedding::fromArray($item['embedding']), data_get($data, 'data', [])), + usage: new EmbeddingsUsage(data_get($data, 'usage.total_tokens')), + meta: new Meta( + id: '', + model: data_get($data, 'model', ''), + ), + ); + } + + protected function sendRequest(Request $request): Response + { + /** @var Response $response */ + $response = $this->client->post( + 'embeddings', + [ + 'input' => $request->inputs(), + ...($request->providerOptions() ?? []), + ] + ); + + return $response; + } + + protected function validateResponse(Response $response): void + { + if ($response->json() === null) { + throw PrismException::providerResponseError('Azure Error: Empty embeddings response'); + } + + if ($response->json('error')) { + throw PrismException::providerResponseError( + sprintf('Azure Error: %s', data_get($response->json(), 'error.message', 'Unknown error')) + ); + } + } +} diff --git a/src/Providers/Azure/Handlers/Stream.php b/src/Providers/Azure/Handlers/Stream.php new file mode 100644 index 000000000..6ea076e02 --- /dev/null +++ b/src/Providers/Azure/Handlers/Stream.php @@ -0,0 +1,376 @@ +state = new StreamState; + } + + /** + * @return Generator + */ + public function handle(Request $request): Generator + { + $response = $this->sendRequest($request); + + yield from $this->processStream($response, $request); + } + + /** + * @return Generator + */ + protected function processStream(Response $response, Request $request, int $depth = 0): Generator + { + if ($depth >= $request->maxSteps()) { + throw new PrismException('Maximum tool call chain depth exceeded'); + } + + if ($depth === 0) { + $this->state->reset(); + } + + $text = ''; + $toolCalls = []; + + while (! $response->getBody()->eof()) { + $data = $this->parseNextDataLine($response->getBody()); + + if ($data === null) { + continue; + } + + if ($this->state->shouldEmitStreamStart()) { + $this->state->withMessageId(EventID::generate())->markStreamStarted(); + + yield new StreamStartEvent( + id: EventID::generate(), + timestamp: time(), + model: $request->model(), + provider: 'azure' + ); + } + + if ($this->hasToolCalls($data)) { + $toolCalls = $this->extractToolCalls($data, $toolCalls); + + $rawFinishReason = data_get($data, 'choices.0.finish_reason'); + if ($rawFinishReason === 'tool_calls') { + if ($this->state->hasTextStarted() && $text !== '') { + yield new TextCompleteEvent( + id: EventID::generate(), + timestamp: time(), + messageId: $this->state->messageId() + ); + } + + yield from $this->handleToolCalls($request, $text, $toolCalls, $depth); + + return; + } + + continue; + } + + $content = $this->extractContentDelta($data); + if ($content !== '' && $content !== '0') { + if ($this->state->shouldEmitTextStart()) { + $this->state->markTextStarted(); + + yield new TextStartEvent( + id: EventID::generate(), + timestamp: time(), + messageId: $this->state->messageId() + ); + } + + $text .= $content; + + yield new TextDeltaEvent( + id: EventID::generate(), + timestamp: time(), + delta: $content, + messageId: $this->state->messageId() + ); + + continue; + } + + $rawFinishReason = data_get($data, 'choices.0.finish_reason'); + if ($rawFinishReason !== null) { + $finishReason = $this->mapFinishReason($data); + + if ($this->state->hasTextStarted() && $text !== '') { + yield new TextCompleteEvent( + id: EventID::generate(), + timestamp: time(), + messageId: $this->state->messageId() + ); + } + + $this->state->withFinishReason($finishReason); + + $usage = $this->extractUsage($data); + if ($usage instanceof Usage) { + $this->state->addUsage($usage); + } + } + } + + if ($toolCalls !== []) { + yield from $this->handleToolCalls($request, $text, $toolCalls, $depth); + + return; + } + + yield new StreamEndEvent( + id: EventID::generate(), + timestamp: time(), + finishReason: $this->state->finishReason() ?? FinishReason::Stop, + usage: $this->state->usage() + ); + } + + /** + * @return array|null + * + * @throws PrismStreamDecodeException + */ + protected function parseNextDataLine(StreamInterface $stream): ?array + { + $line = $this->readLine($stream); + + if (! str_starts_with($line, 'data:')) { + return null; + } + + $line = trim(substr($line, strlen('data: '))); + + if (Str::contains($line, '[DONE]')) { + return null; + } + + try { + return json_decode($line, true, flags: JSON_THROW_ON_ERROR); + } catch (Throwable $e) { + throw new PrismStreamDecodeException('Azure', $e); + } + } + + /** + * @param array $data + */ + protected function hasToolCalls(array $data): bool + { + return ! empty(data_get($data, 'choices.0.delta.tool_calls', [])); + } + + /** + * @param array $data + * @param array> $toolCalls + * @return array> + */ + protected function extractToolCalls(array $data, array $toolCalls): array + { + $deltaToolCalls = data_get($data, 'choices.0.delta.tool_calls', []); + + foreach ($deltaToolCalls as $deltaToolCall) { + $index = data_get($deltaToolCall, 'index', 0); + + if (! isset($toolCalls[$index])) { + $toolCalls[$index] = [ + 'id' => '', + 'name' => '', + 'arguments' => '', + ]; + } + + if ($id = data_get($deltaToolCall, 'id')) { + $toolCalls[$index]['id'] = $id; + } + + if ($name = data_get($deltaToolCall, 'function.name')) { + $toolCalls[$index]['name'] = $name; + } + + if ($arguments = data_get($deltaToolCall, 'function.arguments')) { + $toolCalls[$index]['arguments'] .= $arguments; + } + } + + return $toolCalls; + } + + /** + * @param array $data + */ + protected function extractContentDelta(array $data): string + { + return data_get($data, 'choices.0.delta.content') ?? ''; + } + + /** + * @param array $data + */ + protected function extractUsage(array $data): ?Usage + { + $usage = data_get($data, 'usage'); + + if (! $usage) { + return null; + } + + return new Usage( + promptTokens: (int) data_get($usage, 'prompt_tokens', 0), + completionTokens: (int) data_get($usage, 'completion_tokens', 0) + ); + } + + /** + * @param array> $toolCalls + * @return Generator + */ + protected function handleToolCalls(Request $request, string $text, array $toolCalls, int $depth): Generator + { + $mappedToolCalls = $this->mapToolCalls($toolCalls); + + foreach ($mappedToolCalls as $toolCall) { + yield new ToolCallEvent( + id: EventID::generate(), + timestamp: time(), + toolCall: $toolCall, + messageId: $this->state->messageId() + ); + } + + $toolResults = $this->callTools($request->tools(), $mappedToolCalls); + + foreach ($toolResults as $result) { + yield new ToolResultEvent( + id: EventID::generate(), + timestamp: time(), + toolResult: $result, + messageId: $this->state->messageId() + ); + + foreach ($result->artifacts as $artifact) { + yield new ArtifactEvent( + id: EventID::generate(), + timestamp: time(), + artifact: $artifact, + toolCallId: $result->toolCallId, + toolName: $result->toolName, + messageId: $this->state->messageId(), + ); + } + } + + $request->addMessage(new AssistantMessage($text, $mappedToolCalls)); + $request->addMessage(new ToolResultMessage($toolResults)); + + $this->state->resetTextState(); + $this->state->withMessageId(EventID::generate()); + + $nextResponse = $this->sendRequest($request); + yield from $this->processStream($nextResponse, $request, $depth + 1); + } + + /** + * @param array> $toolCalls + * @return array + */ + protected function mapToolCalls(array $toolCalls): array + { + return array_map(fn (array $toolCall): ToolCall => new ToolCall( + id: data_get($toolCall, 'id'), + name: data_get($toolCall, 'name'), + arguments: data_get($toolCall, 'arguments'), + ), $toolCalls); + } + + /** + * @throws ConnectionException + */ + protected function sendRequest(Request $request): Response + { + /** @var Response $response */ + $response = $this->client->post( + 'chat/completions', + array_merge([ + 'stream' => true, + 'messages' => (new MessageMap($request->messages(), $request->systemPrompts()))(), + 'max_tokens' => $request->maxTokens(), + ], Arr::whereNotNull([ + 'temperature' => $request->temperature(), + 'top_p' => $request->topP(), + 'tools' => ToolMap::map($request->tools()) ?: null, + 'tool_choice' => ToolChoiceMap::map($request->toolChoice()), + ])) + ); + + return $response; + } + + protected function readLine(StreamInterface $stream): string + { + $buffer = ''; + + while (! $stream->eof()) { + $byte = $stream->read(1); + + if ($byte === '') { + return $buffer; + } + + $buffer .= $byte; + + if ($byte === "\n") { + break; + } + } + + return $buffer; + } +} diff --git a/src/Providers/Azure/Handlers/Structured.php b/src/Providers/Azure/Handlers/Structured.php new file mode 100644 index 000000000..536162f5b --- /dev/null +++ b/src/Providers/Azure/Handlers/Structured.php @@ -0,0 +1,115 @@ +responseBuilder = new ResponseBuilder; + } + + public function handle(Request $request): StructuredResponse + { + $request = $this->appendMessageForJsonMode($request); + + $data = $this->sendRequest($request); + + $this->validateResponse($data); + + return $this->createResponse($request, $data); + } + + /** + * @return array + */ + protected function sendRequest(Request $request): array + { + /** @var \Illuminate\Http\Client\Response $response */ + $response = $this->client->post( + 'chat/completions', + array_merge([ + 'messages' => (new MessageMap($request->messages(), $request->systemPrompts()))(), + 'max_tokens' => $request->maxTokens(), + ], Arr::whereNotNull([ + 'temperature' => $request->temperature(), + 'top_p' => $request->topP(), + 'response_format' => ['type' => 'json_object'], + ])) + ); + + return $response->json(); + } + + /** + * @param array $data + */ + protected function validateResponse(array $data): void + { + if ($data === []) { + throw PrismException::providerResponseError('Azure Error: Empty response'); + } + } + + /** + * @param array $data + */ + protected function createResponse(Request $request, array $data): StructuredResponse + { + $text = data_get($data, 'choices.0.message.content') ?? ''; + + $responseMessage = new AssistantMessage($text); + $request->addMessage($responseMessage); + + $step = new Step( + text: $text, + finishReason: FinishReasonMap::map(data_get($data, 'choices.0.finish_reason', '')), + usage: new Usage( + data_get($data, 'usage.prompt_tokens'), + data_get($data, 'usage.completion_tokens'), + ), + meta: new Meta( + id: data_get($data, 'id'), + model: data_get($data, 'model'), + ), + messages: $request->messages(), + systemPrompts: $request->systemPrompts(), + additionalContent: [], + ); + + $this->responseBuilder->addStep($step); + + return $this->responseBuilder->toResponse(); + } + + protected function appendMessageForJsonMode(Request $request): Request + { + return $request->addMessage(new SystemMessage(sprintf( + "You MUST respond EXCLUSIVELY with a JSON object that strictly adheres to the following schema. \n Do NOT explain or add other content. Validate your response against this schema \n %s", + json_encode($request->schema()->toArray(), JSON_PRETTY_PRINT) + ))); + } +} diff --git a/src/Providers/Azure/Handlers/Text.php b/src/Providers/Azure/Handlers/Text.php new file mode 100644 index 000000000..a0a964f6c --- /dev/null +++ b/src/Providers/Azure/Handlers/Text.php @@ -0,0 +1,147 @@ +responseBuilder = new ResponseBuilder; + } + + public function handle(Request $request): TextResponse + { + $data = $this->sendRequest($request); + + $this->validateResponse($data); + + $responseMessage = new AssistantMessage( + data_get($data, 'choices.0.message.content') ?? '', + ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', [])), + [] + ); + + $request = $request->addMessage($responseMessage); + + return match ($this->mapFinishReason($data)) { + FinishReason::ToolCalls => $this->handleToolCalls($data, $request), + FinishReason::Stop => $this->handleStop($data, $request), + FinishReason::Length => throw new PrismException('Azure: max tokens exceeded'), + FinishReason::ContentFilter => throw new PrismException('Azure: content filter triggered'), + default => throw new PrismException('Azure: unknown finish reason'), + }; + } + + /** + * @param array $data + */ + protected function handleToolCalls(array $data, Request $request): TextResponse + { + $toolResults = $this->callTools( + $request->tools(), + ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', [])) + ); + + $request = $request->addMessage(new ToolResultMessage($toolResults)); + + $this->addStep($data, $request, $toolResults); + + if ($this->shouldContinue($request)) { + return $this->handle($request); + } + + return $this->responseBuilder->toResponse(); + } + + /** + * @param array $data + */ + protected function handleStop(array $data, Request $request): TextResponse + { + $this->addStep($data, $request); + + return $this->responseBuilder->toResponse(); + } + + protected function shouldContinue(Request $request): bool + { + return $this->responseBuilder->steps->count() < $request->maxSteps(); + } + + /** + * @return array + */ + protected function sendRequest(Request $request): array + { + /** @var \Illuminate\Http\Client\Response $response */ + $response = $this->client->post( + 'chat/completions', + array_merge([ + 'messages' => (new MessageMap($request->messages(), $request->systemPrompts()))(), + 'max_tokens' => $request->maxTokens(), + ], Arr::whereNotNull([ + 'temperature' => $request->temperature(), + 'top_p' => $request->topP(), + 'tools' => ToolMap::map($request->tools()) ?: null, + 'tool_choice' => ToolChoiceMap::map($request->toolChoice()), + ])) + ); + + return $response->json(); + } + + /** + * @param array $data + * @param array $toolResults + */ + protected function addStep(array $data, Request $request, array $toolResults = []): void + { + $this->responseBuilder->addStep(new Step( + text: data_get($data, 'choices.0.message.content') ?? '', + finishReason: $this->mapFinishReason($data), + toolCalls: ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', [])), + toolResults: $toolResults, + providerToolCalls: [], + usage: new Usage( + data_get($data, 'usage.prompt_tokens'), + data_get($data, 'usage.completion_tokens'), + ), + meta: new Meta( + id: data_get($data, 'id'), + model: data_get($data, 'model'), + ), + messages: $request->messages(), + systemPrompts: $request->systemPrompts(), + additionalContent: [], + )); + } +} diff --git a/src/Providers/Azure/Maps/FinishReasonMap.php b/src/Providers/Azure/Maps/FinishReasonMap.php new file mode 100644 index 000000000..af08d6547 --- /dev/null +++ b/src/Providers/Azure/Maps/FinishReasonMap.php @@ -0,0 +1,21 @@ + FinishReason::Stop, + 'tool_calls' => FinishReason::ToolCalls, + 'length' => FinishReason::Length, + 'content_filter' => FinishReason::ContentFilter, + default => FinishReason::Unknown, + }; + } +} diff --git a/src/Providers/Azure/Maps/ImageMapper.php b/src/Providers/Azure/Maps/ImageMapper.php new file mode 100644 index 000000000..0543c1fac --- /dev/null +++ b/src/Providers/Azure/Maps/ImageMapper.php @@ -0,0 +1,44 @@ + + */ + public function toPayload(): array + { + return [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $this->media->isUrl() + ? $this->media->url() + : sprintf('data:%s;base64,%s', $this->media->mimeType(), $this->media->base64()), + ], + ]; + } + + protected function provider(): string|Provider + { + return Provider::Azure; + } + + protected function validateMedia(): bool + { + if ($this->media->isUrl()) { + return true; + } + + return $this->media->hasRawContent(); + } +} diff --git a/src/Providers/Azure/Maps/MessageMap.php b/src/Providers/Azure/Maps/MessageMap.php new file mode 100644 index 000000000..a10e1a099 --- /dev/null +++ b/src/Providers/Azure/Maps/MessageMap.php @@ -0,0 +1,108 @@ + */ + protected array $mappedMessages = []; + + /** + * @param array $messages + * @param SystemMessage[] $systemPrompts + */ + public function __construct( + protected array $messages, + protected array $systemPrompts + ) { + $this->messages = array_merge( + $this->systemPrompts, + $this->messages + ); + } + + /** + * @return array + */ + public function __invoke(): array + { + array_map( + $this->mapMessage(...), + $this->messages + ); + + return $this->mappedMessages; + } + + protected function mapMessage(Message $message): void + { + match ($message::class) { + UserMessage::class => $this->mapUserMessage($message), + AssistantMessage::class => $this->mapAssistantMessage($message), + ToolResultMessage::class => $this->mapToolResultMessage($message), + SystemMessage::class => $this->mapSystemMessage($message), + default => throw new Exception('Could not map message type '.$message::class), + }; + } + + protected function mapSystemMessage(SystemMessage $message): void + { + $this->mappedMessages[] = [ + 'role' => 'system', + 'content' => $message->content, + ]; + } + + protected function mapToolResultMessage(ToolResultMessage $message): void + { + foreach ($message->toolResults as $toolResult) { + $this->mappedMessages[] = [ + 'role' => 'tool', + 'tool_call_id' => $toolResult->toolCallId, + 'content' => $toolResult->result, + ]; + } + } + + protected function mapUserMessage(UserMessage $message): void + { + $imageParts = array_map(fn (Image $image): array => (new ImageMapper($image))->toPayload(), $message->images()); + + $this->mappedMessages[] = [ + 'role' => 'user', + 'content' => [ + ['type' => 'text', 'text' => $message->text()], + ...$imageParts, + ], + ]; + } + + protected function mapAssistantMessage(AssistantMessage $message): void + { + $toolCalls = array_map(fn (ToolCall $toolCall): array => [ + 'id' => $toolCall->id, + 'type' => 'function', + 'function' => [ + 'name' => $toolCall->name, + 'arguments' => json_encode($toolCall->arguments()), + ], + ], $message->toolCalls); + + $this->mappedMessages[] = array_filter([ + 'role' => 'assistant', + 'content' => $message->content, + 'tool_calls' => $toolCalls, + ]); + } +} diff --git a/src/Providers/Azure/Maps/ToolCallMap.php b/src/Providers/Azure/Maps/ToolCallMap.php new file mode 100644 index 000000000..d15a02ecf --- /dev/null +++ b/src/Providers/Azure/Maps/ToolCallMap.php @@ -0,0 +1,23 @@ +> $toolCalls + * @return array + */ + public static function map(array $toolCalls): array + { + return array_map(fn (array $toolCall): ToolCall => new ToolCall( + id: data_get($toolCall, 'id'), + name: data_get($toolCall, 'function.name'), + arguments: data_get($toolCall, 'function.arguments'), + ), $toolCalls); + } +} diff --git a/src/Providers/Azure/Maps/ToolChoiceMap.php b/src/Providers/Azure/Maps/ToolChoiceMap.php new file mode 100644 index 000000000..4d4ac14fb --- /dev/null +++ b/src/Providers/Azure/Maps/ToolChoiceMap.php @@ -0,0 +1,32 @@ +|string|null + */ + public static function map(string|ToolChoice|null $toolChoice): string|array|null + { + if (is_string($toolChoice)) { + return [ + 'type' => 'function', + 'function' => [ + 'name' => $toolChoice, + ], + ]; + } + + return match ($toolChoice) { + ToolChoice::Auto => 'auto', + null => $toolChoice, + default => throw new InvalidArgumentException('Invalid tool choice') + }; + } +} diff --git a/src/Providers/Azure/Maps/ToolMap.php b/src/Providers/Azure/Maps/ToolMap.php new file mode 100644 index 000000000..921c788e8 --- /dev/null +++ b/src/Providers/Azure/Maps/ToolMap.php @@ -0,0 +1,33 @@ + + */ + public static function map(array $tools): array + { + return array_map(fn (Tool $tool): array => array_filter([ + 'type' => 'function', + 'function' => [ + 'name' => $tool->name(), + 'description' => $tool->description(), + ...$tool->hasParameters() ? [ + 'parameters' => [ + 'type' => 'object', + 'properties' => $tool->parametersAsArray(), + 'required' => $tool->requiredParameters(), + ], + ] : [], + ], + 'strict' => $tool->providerOptions('strict'), + ]), $tools); + } +}