diff --git a/src/main/java/org/sopt/kareer/global/external/ai/builder/context/MemberContextBuilder.java b/src/main/java/org/sopt/kareer/global/external/ai/builder/context/MemberContextBuilder.java index ef2f4672..b9b1e5b9 100644 --- a/src/main/java/org/sopt/kareer/global/external/ai/builder/context/MemberContextBuilder.java +++ b/src/main/java/org/sopt/kareer/global/external/ai/builder/context/MemberContextBuilder.java @@ -9,6 +9,7 @@ import org.sopt.kareer.domain.member.repository.MemberVisaRepository; import org.springframework.stereotype.Component; +import java.time.LocalDate; import java.util.List; import static org.sopt.kareer.domain.member.exception.MemberErrorCode.MEMBER_NOT_FOUND; @@ -27,17 +28,23 @@ public MemberAndContext load(Long memberId) { StringBuilder sb = new StringBuilder(); sb.append("User Profile\n"); - sb.append("- name: ").append(nullToEmpty(member.getName())).append("\n"); - sb.append("- country: ").append(member.getCountry() != null ? member.getCountry().name() : "").append("\n"); - sb.append("- primaryMajor: ").append(nullToEmpty(member.getPrimaryMajor())).append("\n"); - sb.append("- secondaryMajor: ").append(nullToEmpty(member.getSecondaryMajor())).append("\n"); - sb.append("- targetJob: ").append(nullToEmpty(member.getTargetJob())).append("\n"); - sb.append("- languageLevel: ").append(member.getLanguageLevel() != null ? member.getLanguageLevel().name() : "").append("\n"); - sb.append("- degree: ").append(member.getDegree() != null ? member.getDegree().name() : "").append("\n"); - sb.append("- graduationDate: ").append(member.getGraduationDate() != null ? member.getGraduationDate() : "").append("\n"); - sb.append("- expectedGraduationDate: ").append(member.getExpectedGraduationDate() != null ? member.getExpectedGraduationDate() : "").append("\n"); - sb.append("- targetJobSkill: ").append(nullToEmpty(member.getTargetJobSkill())).append("\n"); - sb.append("- personalBackground: ").append(nullToEmpty(member.getPersonalBackground())).append("\n"); + appendLine(sb, "name", member.getName()); + appendLine(sb, "email", member.getEmail()); + appendLine(sb, "birthDate", member.getBirthDate()); + appendLine(sb, "country", member.getCountry() != null ? member.getCountry().name() : ""); + appendLine(sb, "university", member.getUniversity()); + appendLine(sb, "primaryMajor", member.getPrimaryMajor()); + appendLine(sb, "secondaryMajor", member.getSecondaryMajor()); + appendLine(sb, "targetJob", member.getTargetJob()); + appendLine(sb, "targetJobSkill", member.getTargetJobSkill()); + appendLine(sb, "fieldsOfInterest", member.getFieldsOfInterest()); + appendLine(sb, "preparationStatus", member.getPreparationStatus()); + appendLine(sb, "personalBackground", member.getPersonalBackground()); + appendLine(sb, "languageLevel", member.getLanguageLevel() != null ? member.getLanguageLevel().name() : ""); + appendLine(sb, "englishLevel", member.getEnglishLevel() != null ? member.getEnglishLevel().name() : ""); + appendLine(sb, "degree", member.getDegree() != null ? member.getDegree().name() : ""); + appendLine(sb, "graduationDate", member.getGraduationDate()); + appendLine(sb, "expectedGraduationDate", member.getExpectedGraduationDate()); sb.append("Visa Info\n"); for (MemberVisa v : visas) { @@ -56,5 +63,17 @@ private String nullToEmpty(String s) { return s == null ? "" : s; } + private void appendLine(StringBuilder sb, String key, String value) { + sb.append("- ").append(key).append(": ").append(nullToEmpty(value)).append("\n"); + } + + private void appendLine(StringBuilder sb, String key, LocalDate value) { + sb.append("- ").append(key).append(": ").append(toText(value)).append("\n"); + } + + private String toText(LocalDate value) { + return value == null ? "" : value.toString(); + } + public record MemberAndContext(Member member, String contextText) {} } diff --git a/src/main/java/org/sopt/kareer/global/external/ai/service/PolicyDocumentRetriever.java b/src/main/java/org/sopt/kareer/global/external/ai/service/PolicyDocumentRetriever.java index 43e41ea1..2a1836a9 100644 --- a/src/main/java/org/sopt/kareer/global/external/ai/service/PolicyDocumentRetriever.java +++ b/src/main/java/org/sopt/kareer/global/external/ai/service/PolicyDocumentRetriever.java @@ -5,6 +5,7 @@ import org.sopt.kareer.domain.member.entity.MemberVisa; import org.sopt.kareer.global.external.ai.builder.query.PolicyQueryBuilder; import org.sopt.kareer.global.external.ai.properties.RoadmapRagProperties; +import org.sopt.kareer.global.external.cohere.service.CohereRerankClient; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.pgvector.PgVectorStore; @@ -18,16 +19,24 @@ public class PolicyDocumentRetriever { private final PgVectorStore policyDocumentVectorStore; private final RoadmapRagProperties props; + private final CohereRerankClient cohereRerankClient; public List retrievePolicy(Member member, MemberVisa visa) { String query = PolicyQueryBuilder.buildPolicyQuery(member, visa); - return policyDocumentVectorStore.similaritySearch( + List candidates = policyDocumentVectorStore.similaritySearch( SearchRequest.builder() .query(query) - .topK(props.policyTopK()) + .topK(props.candidatePoolTopK()) .build() ); + + List reranked = cohereRerankClient.rerank(query, candidates, props.policyTopK()); + if(reranked.size() > props.policyTopK()) { + return reranked.subList(0, props.policyTopK()); + } + return reranked; + } } diff --git a/src/main/java/org/sopt/kareer/global/external/cohere/dto/request/CohereRerankRequest.java b/src/main/java/org/sopt/kareer/global/external/cohere/dto/request/CohereRerankRequest.java new file mode 100644 index 00000000..a4f86fba --- /dev/null +++ b/src/main/java/org/sopt/kareer/global/external/cohere/dto/request/CohereRerankRequest.java @@ -0,0 +1,13 @@ +package org.sopt.kareer.global.external.cohere.dto.request; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record CohereRerankRequest( + String model, + String query, + List documents, + @JsonProperty("top_n") + Integer topN +) {} diff --git a/src/main/java/org/sopt/kareer/global/external/cohere/dto/response/CohereRerankResponse.java b/src/main/java/org/sopt/kareer/global/external/cohere/dto/response/CohereRerankResponse.java new file mode 100644 index 00000000..9cf6d907 --- /dev/null +++ b/src/main/java/org/sopt/kareer/global/external/cohere/dto/response/CohereRerankResponse.java @@ -0,0 +1,16 @@ +package org.sopt.kareer.global.external.cohere.dto.response; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record CohereRerankResponse( + String id, + List results +) { + public record Result( + Integer index, + @JsonProperty("relevance_score") + Double relevanceScore + ) {} +} diff --git a/src/main/java/org/sopt/kareer/global/external/cohere/properties/CohereProperties.java b/src/main/java/org/sopt/kareer/global/external/cohere/properties/CohereProperties.java new file mode 100644 index 00000000..36efacb1 --- /dev/null +++ b/src/main/java/org/sopt/kareer/global/external/cohere/properties/CohereProperties.java @@ -0,0 +1,15 @@ +package org.sopt.kareer.global.external.cohere.properties; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(prefix = "cohere") +public record CohereProperties( + String apiKey, + String baseUrl, + Rerank rerank +) { + public record Rerank( + String model, + int topN + ) {} +} diff --git a/src/main/java/org/sopt/kareer/global/external/cohere/service/CohereRerankClient.java b/src/main/java/org/sopt/kareer/global/external/cohere/service/CohereRerankClient.java new file mode 100644 index 00000000..c5c95892 --- /dev/null +++ b/src/main/java/org/sopt/kareer/global/external/cohere/service/CohereRerankClient.java @@ -0,0 +1,74 @@ +package org.sopt.kareer.global.external.cohere.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.sopt.kareer.global.external.cohere.dto.request.CohereRerankRequest; +import org.sopt.kareer.global.external.cohere.dto.response.CohereRerankResponse; +import org.sopt.kareer.global.external.cohere.properties.CohereProperties; +import org.springframework.ai.document.Document; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.client.RestClient; + +import java.util.ArrayList; +import java.util.List; + +@Slf4j +@Component +@RequiredArgsConstructor +public class CohereRerankClient { + + private final CohereProperties cohereProperties; + + private RestClient restClient() { + return RestClient.builder() + .baseUrl(cohereProperties.baseUrl()) + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + cohereProperties.apiKey()) + .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .build(); + } + + public List rerank(String query, List documents, Integer topN) { + if (documents == null || documents.isEmpty()) { + return List.of(); + } + + List serializedDocs = documents.stream() + .map(Document::getText) + .toList(); + + CohereRerankRequest request = new CohereRerankRequest( + cohereProperties.rerank().model(), + query, + serializedDocs, + topN != null + ? Math.min(topN, documents.size()) + : Math.min(cohereProperties.rerank().topN(), documents.size()) + ); + + try { + CohereRerankResponse response = restClient() + .post() + .uri("/v2/rerank") + .body(request) + .retrieve() + .body(CohereRerankResponse.class); + + if (response == null || response.results() == null || response.results().isEmpty()) { + log.warn("Cohere rerank response empty. query={}", query); + return documents; + } + + List reranked = new ArrayList<>(); + for (CohereRerankResponse.Result result : response.results()) { + reranked.add(documents.get(result.index())); + } + return reranked; + + } catch (Exception e) { + log.error("Cohere rerank failed. fallback to original order. query={}", query, e); + return documents; + } + } +} \ No newline at end of file diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index e094aa90..ba2e40db 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -107,6 +107,13 @@ management: enabled: true health: show-details: always + +cohere: + api-key: ${COHERE_API_KEY} + base-url: https://api.cohere.com + rerank: + model: rerank-multilingual-v3.0 + top-n: 10 --- spring: config: