diff --git a/frontend/src/components/experiment_builder/structured_prompt_editor.ts b/frontend/src/components/experiment_builder/structured_prompt_editor.ts index 10cf5c752..666b8537a 100644 --- a/frontend/src/components/experiment_builder/structured_prompt_editor.ts +++ b/frontend/src/components/experiment_builder/structured_prompt_editor.ts @@ -26,7 +26,6 @@ import { SeedStrategy, ShuffleConfig, StageContextPromptItem, - StageKind, TextPromptItem, } from '@deliberation-lab/utils'; @@ -54,12 +53,6 @@ export class EditorComponent extends MobxLitElement { ); } - /** Check if the current stage supports conditions (only private chat, not group chat). */ - private supportsConditions(): boolean { - const stage = this.experimentEditor.getStage(this.stageId); - return stage?.kind === StageKind.PRIVATE_CHAT; - } - override render() { return this.renderPromptPreview(); } @@ -213,8 +206,7 @@ export class EditorComponent extends MobxLitElement { } const conditionTargets = this.getConditionTargets(); - const supportsConditions = - this.supportsConditions() && conditionTargets.length > 0; + const hasConditionTargets = conditionTargets.length > 0; return items.map((item, index) => { const hasCondition = item.condition !== undefined; @@ -224,7 +216,7 @@ export class EditorComponent extends MobxLitElement {
${this.renderItemEditor(item)}
- ${supportsConditions && item.type !== PromptItemType.GROUP + ${hasConditionTargets && item.type !== PromptItemType.GROUP ? html` void = () => {}; @@ -65,6 +71,8 @@ export class ConditionEditor extends MobxLitElement { private renderCondition(condition: Condition): TemplateResult { if (condition.type === 'group') { return this.renderConditionGroup(condition); + } else if (condition.type === 'aggregation') { + return this.renderAggregationCondition(condition); } else { return this.renderComparisonCondition(condition); } @@ -151,6 +159,16 @@ export class ConditionEditor extends MobxLitElement { add Add condition + ${this.allowAggregation + ? html` + this.addAggregationToGroup(group)} + > + groups + Add aggregation + + ` + : nothing} this.addSubgroupToGroup(group)}> add_circle Add subgroup @@ -290,6 +308,229 @@ export class ConditionEditor extends MobxLitElement { } } + private renderAggregationCondition(condition: AggregationCondition) { + const conditionKey = this.getTargetKey(condition.target); + const target = this.targets.find( + (t) => this.getTargetKey(t.ref) === conditionKey, + ); + + const needsFilterComparison = + condition.aggregator === AggregationOperator.COUNT || + condition.aggregator === AggregationOperator.SUM || + condition.aggregator === AggregationOperator.AVERAGE; + + return html` +
+
+ + this.updateAggregationOperator( + condition, + (e.target as HTMLSelectElement).value as AggregationOperator, + )} + > + ${Object.values(AggregationOperator).map( + (op) => html` + +
${getAggregationOperatorLabel(op)}
+
+ `, + )} +
+ + where + + + this.updateAggregationTarget( + condition, + (e.target as HTMLSelectElement).value, + )} + > + ${this.targets.map( + (t) => html` + +
+ ${t.stageName ? `[${t.stageName}] ${t.label}` : t.label} +
+
+ `, + )} +
+
+ + ${needsFilterComparison + ? html` +
+ Filter: value + + this.updateAggregationFilterOperator( + condition, + (e.target as HTMLSelectElement) + .value as ComparisonOperator, + )} + > + ${this.getAvailableOperators(target?.type).map( + (op) => html` + +
+ ${getComparisonOperatorLabel(op)} +
+
+ `, + )} +
+ + this.updateAggregationFilterValue( + condition, + Number((e.target as HTMLInputElement).value), + )} + placeholder="Value" + > +
+ ` + : nothing} + +
+ Result + + this.updateAggregationComparisonOperator( + condition, + (e.target as HTMLSelectElement).value as ComparisonOperator, + )} + > + ${this.getAggregationResultOperators(condition.aggregator).map( + (op) => html` + +
${getComparisonOperatorLabel(op)}
+
+ `, + )} +
+ ${this.renderAggregationValue(condition)} +
+
+ `; + } + + private renderAggregationValue(condition: AggregationCondition) { + // For ANY/ALL/NONE, show target-appropriate value selector + // For COUNT/SUM/AVERAGE, show number input + if ( + condition.aggregator === AggregationOperator.COUNT || + condition.aggregator === AggregationOperator.SUM || + condition.aggregator === AggregationOperator.AVERAGE + ) { + return html` + + this.updateAggregationValue( + condition, + Number((e.target as HTMLInputElement).value), + )} + placeholder="Threshold" + > + `; + } + + // For ANY/ALL/NONE, use the target type for value selection + const conditionKey = this.getTargetKey(condition.target); + const target = this.targets.find( + (t) => this.getTargetKey(t.ref) === conditionKey, + ); + + if (!target) return nothing; + + if (target.type === 'boolean') { + return html` + + this.updateAggregationValue( + condition, + (e.target as HTMLSelectElement).value === 'true', + )} + > + +
Yes/Checked
+
+ +
No/Unchecked
+
+
+ `; + } else if (target.type === 'choice' && target.choices) { + return html` + + this.updateAggregationValue( + condition, + (e.target as HTMLSelectElement).value, + )} + > + ${target.choices.map( + (choice) => html` + +
${choice.label}
+
+ `, + )} +
+ `; + } else { + return html` + + this.updateAggregationValue( + condition, + target.type === 'number' + ? Number((e.target as HTMLInputElement).value) + : (e.target as HTMLInputElement).value, + )} + placeholder="Value" + > + `; + } + } + + private getAggregationResultOperators( + _aggregator: AggregationOperator, + ): ComparisonOperator[] { + // All aggregation types support the same comparison operators + return [ + ComparisonOperator.EQUALS, + ComparisonOperator.NOT_EQUALS, + ComparisonOperator.GREATER_THAN, + ComparisonOperator.GREATER_THAN_OR_EQUAL, + ComparisonOperator.LESS_THAN, + ComparisonOperator.LESS_THAN_OR_EQUAL, + ]; + } + private getAvailableOperators(type?: string): ComparisonOperator[] { if (!type) return [ComparisonOperator.EQUALS]; @@ -373,6 +614,26 @@ export class ConditionEditor extends MobxLitElement { } } + private createDefaultAggregation(): AggregationCondition | null { + const firstTarget = this.targets[0]; + if (!firstTarget) return null; + + return createAggregationCondition( + firstTarget.ref, + AggregationOperator.ANY, + ComparisonOperator.EQUALS, + this.getDefaultValue(firstTarget), + ); + } + + private addAggregationToGroup(group: ConditionGroup) { + const aggregation = this.createDefaultAggregation(); + if (aggregation) { + group.conditions.push(aggregation); + this.onConditionChange(this.condition); + } + } + private createSubgroupWithComparison(): ConditionGroup { const comparison = this.createDefaultComparison(); const conditions = comparison ? [comparison] : []; @@ -432,4 +693,95 @@ export class ConditionEditor extends MobxLitElement { condition.value = value; this.onConditionChange(this.condition); } + + // Aggregation condition update methods + private updateAggregationOperator( + condition: AggregationCondition, + aggregator: AggregationOperator, + ) { + condition.aggregator = aggregator; + + // Initialize filterComparison for COUNT/SUM/AVERAGE if not present + if ( + (aggregator === AggregationOperator.COUNT || + aggregator === AggregationOperator.SUM || + aggregator === AggregationOperator.AVERAGE) && + !condition.filterComparison + ) { + condition.filterComparison = { + operator: ComparisonOperator.GREATER_THAN, + value: 0, + }; + // Reset result value to 0 for numeric aggregations + condition.value = 0; + } + + this.onConditionChange(this.condition); + } + + private updateAggregationTarget( + condition: AggregationCondition, + targetKey: string, + ) { + const target = this.targets.find( + (t) => this.getTargetKey(t.ref) === targetKey, + ); + + if (target) { + condition.target = target.ref; + // Reset value based on target type (for non-COUNT/SUM/AVERAGE) + if ( + condition.aggregator !== AggregationOperator.COUNT && + condition.aggregator !== AggregationOperator.SUM && + condition.aggregator !== AggregationOperator.AVERAGE + ) { + condition.value = this.getDefaultValue(target); + } + } + + this.onConditionChange(this.condition); + } + + private updateAggregationComparisonOperator( + condition: AggregationCondition, + operator: ComparisonOperator, + ) { + condition.operator = operator; + this.onConditionChange(this.condition); + } + + private updateAggregationValue( + condition: AggregationCondition, + value: string | number | boolean, + ) { + condition.value = value; + this.onConditionChange(this.condition); + } + + private updateAggregationFilterOperator( + condition: AggregationCondition, + operator: ComparisonOperator, + ) { + if (!condition.filterComparison) { + condition.filterComparison = {operator, value: 0}; + } else { + condition.filterComparison.operator = operator; + } + this.onConditionChange(this.condition); + } + + private updateAggregationFilterValue( + condition: AggregationCondition, + value: number, + ) { + if (!condition.filterComparison) { + condition.filterComparison = { + operator: ComparisonOperator.GREATER_THAN, + value, + }; + } else { + condition.filterComparison.value = value; + } + this.onConditionChange(this.condition); + } } diff --git a/frontend/src/components/stages/survey_per_participant_view.ts b/frontend/src/components/stages/survey_per_participant_view.ts index 7af7d6e86..851fe1755 100644 --- a/frontend/src/components/stages/survey_per_participant_view.ts +++ b/frontend/src/components/stages/survey_per_participant_view.ts @@ -29,6 +29,9 @@ import { SurveyAnswer, SurveyPerParticipantStageConfig, SurveyStageParticipantAnswer, + SurveyStagePublicData, + StageKind, + StageParticipantAnswer, TextSurveyAnswer, TextSurveyQuestion, isMultipleChoiceImageQuestion, @@ -75,6 +78,46 @@ export class SurveyView extends MobxLitElement { ); } + /** + * Build multi-participant answers from the cohort's public stage data. + * Used for evaluating aggregation conditions. + */ + private getAllParticipantAnswers(): Record< + string, + Record + > { + const result: Record> = {}; + + // Iterate over all public stage data in the cohort + for (const [stageId, publicData] of Object.entries( + this.cohortService.stagePublicDataMap, + )) { + if (publicData.kind === StageKind.SURVEY) { + const surveyPublicData = publicData as SurveyStagePublicData; + + // For each participant in this stage's public data + for (const [participantId, answerMap] of Object.entries( + surveyPublicData.participantAnswerMap, + )) { + // Initialize participant's answer record if needed + if (!result[participantId]) { + result[participantId] = {}; + } + + // Create a SurveyStageParticipantAnswer from the answer map + const stageAnswer: SurveyStageParticipantAnswer = { + id: stageId, + kind: StageKind.SURVEY, + answerMap: answerMap, + }; + result[participantId][stageId] = stageAnswer; + } + } + } + + return result; + } + private renderParticipant(profile: ParticipantProfile) { const isCurrent = profile.publicId === this.participantService.profile?.publicId; @@ -98,6 +141,7 @@ export class SurveyView extends MobxLitElement { if (!this.stage) return false; const participants = this.getParticipants(); const allStageAnswers = this.participantAnswerService.answerMap; + const allParticipantAnswers = this.getAllParticipantAnswers(); for (const participant of participants) { const answerMap = @@ -113,6 +157,7 @@ export class SurveyView extends MobxLitElement { answerMap, allStageAnswers, participant.publicId, + allParticipantAnswers, ); if (!isSurveyComplete(visibleQuestions, answerMap)) { @@ -205,12 +250,16 @@ export class SurveyView extends MobxLitElement { // Get all stage answers const allStageAnswers = this.participantAnswerService.answerMap; + // Get all participants' answers for aggregation conditions + const allParticipantAnswers = this.getAllParticipantAnswers(); + return isQuestionVisible( question, this.stage.id, currentAnswers, allStageAnswers, participant.publicId, // Pass which participant is being evaluated + allParticipantAnswers, ); } diff --git a/frontend/src/components/stages/survey_view.ts b/frontend/src/components/stages/survey_view.ts index feeaa3316..04647f5a2 100644 --- a/frontend/src/components/stages/survey_view.ts +++ b/frontend/src/components/stages/survey_view.ts @@ -24,6 +24,10 @@ import { ScaleSurveyQuestion, SurveyQuestionKind, SurveyStageConfig, + SurveyStagePublicData, + StageKind, + StageParticipantAnswer, + SurveyStageParticipantAnswer, TextSurveyAnswer, TextSurveyQuestion, isMultipleChoiceImageQuestion, @@ -36,6 +40,7 @@ import { import {unsafeHTML} from 'lit/directives/unsafe-html.js'; import {convertMarkdownToHTML} from '../../shared/utils'; import {core} from '../../core/core'; +import {CohortService} from '../../services/cohort.service'; import {ParticipantService} from '../../services/participant.service'; import {ParticipantAnswerService} from '../../services/participant.answer'; @@ -46,6 +51,7 @@ import {styles} from './survey_view.scss'; export class SurveyView extends MobxLitElement { static override styles: CSSResultGroup = [styles]; + private readonly cohortService = core.getService(CohortService); private readonly participantService = core.getService(ParticipantService); private readonly participantAnswerService = core.getService( ParticipantAnswerService, @@ -53,6 +59,46 @@ export class SurveyView extends MobxLitElement { @property() stage: SurveyStageConfig | undefined = undefined; + /** + * Build multi-participant answers from the cohort's public stage data. + * Used for evaluating aggregation conditions. + */ + private getAllParticipantAnswers(): Record< + string, + Record + > { + const result: Record> = {}; + + // Iterate over all public stage data in the cohort + for (const [stageId, publicData] of Object.entries( + this.cohortService.stagePublicDataMap, + )) { + if (publicData.kind === StageKind.SURVEY) { + const surveyPublicData = publicData as SurveyStagePublicData; + + // For each participant in this stage's public data + for (const [participantId, answerMap] of Object.entries( + surveyPublicData.participantAnswerMap, + )) { + // Initialize participant's answer record if needed + if (!result[participantId]) { + result[participantId] = {}; + } + + // Create a SurveyStageParticipantAnswer from the answer map + const stageAnswer: SurveyStageParticipantAnswer = { + id: stageId, + kind: StageKind.SURVEY, + answerMap: answerMap, + }; + result[participantId][stageId] = stageAnswer; + } + } + } + + return result; + } + override render() { if (!this.stage) { return nothing; @@ -64,11 +110,15 @@ export class SurveyView extends MobxLitElement { const currentSurveyAnswers = this.participantAnswerService.getSurveyAnswerMap(this.stage.id); const allStageAnswers = this.participantAnswerService.answerMap; + const allParticipantAnswers = this.getAllParticipantAnswers(); + const currentParticipantId = this.participantService.profile?.publicId; const visibleQuestions = getVisibleSurveyQuestions( this.stage.questions, this.stage.id, currentSurveyAnswers, allStageAnswers, + currentParticipantId, + allParticipantAnswers, ); return isSurveyComplete(visibleQuestions, currentSurveyAnswers); @@ -110,11 +160,16 @@ export class SurveyView extends MobxLitElement { this.participantAnswerService.getSurveyAnswerMap(this.stage.id); const allStageAnswers = this.participantAnswerService.answerMap; + const allParticipantAnswers = this.getAllParticipantAnswers(); + const currentParticipantId = this.participantService.profile?.publicId; + return isQuestionVisible( question, this.stage.id, currentSurveyAnswers, allStageAnswers, + currentParticipantId, + allParticipantAnswers, ); } diff --git a/frontend/src/shared/condition_editor.utils.ts b/frontend/src/shared/condition_editor.utils.ts index 5c9ccaa4b..c6d51f70a 100644 --- a/frontend/src/shared/condition_editor.utils.ts +++ b/frontend/src/shared/condition_editor.utils.ts @@ -9,6 +9,8 @@ export interface RenderConditionEditorOptions { targets: ConditionTarget[]; canEdit: boolean; onConditionChange: (condition: Condition | undefined) => void; + /** Whether to allow aggregation conditions (defaults to true) */ + allowAggregation?: boolean; } /** @@ -16,7 +18,13 @@ export interface RenderConditionEditorOptions { * Returns nothing if there are no valid targets. */ export function renderConditionEditor(options: RenderConditionEditorOptions) { - const {condition, targets, canEdit, onConditionChange} = options; + const { + condition, + targets, + canEdit, + onConditionChange, + allowAggregation = true, + } = options; if (targets.length === 0) return nothing; @@ -25,6 +33,7 @@ export function renderConditionEditor(options: RenderConditionEditorOptions) { .condition=${condition} .targets=${targets} .disabled=${!canEdit} + .allowAggregation=${allowAggregation} .onConditionChange=${onConditionChange} > `; diff --git a/functions/src/structured_prompt.utils.test.ts b/functions/src/structured_prompt.utils.test.ts index 1eaf32f93..0029d6a97 100644 --- a/functions/src/structured_prompt.utils.test.ts +++ b/functions/src/structured_prompt.utils.test.ts @@ -4,6 +4,16 @@ import { MediatorProfileExtended, BasePromptConfig, PromptItemType, + PromptItemGroup, + TextPromptItem, + StageKind, + SurveyQuestionKind, + createComparisonCondition, + createConditionGroup, + createAggregationCondition, + ConditionOperator, + ComparisonOperator, + AggregationOperator, } from '@deliberation-lab/utils'; import {getFirestoreDataForStructuredPrompt} from './structured_prompt.utils'; import * as firestoreUtils from './utils/firestore'; @@ -251,4 +261,641 @@ describe('structured_prompt.utils', () => { expect(result.participants).toHaveLength(2); }); }); + + describe('condition filtering', () => { + const mockExperimentId = 'test-experiment'; + const mockCohortId = 'test-cohort'; + const mockStageId = 'chat-stage'; + const mockSurveyStageId = 'survey-stage'; + + const mockParticipant: ParticipantProfileExtended = { + id: 'participant-1', + privateId: 'participant-private-1', + publicId: 'participant-public-1', + type: UserType.PARTICIPANT, + name: 'Test Participant', + avatar: '🐶', + pronouns: 'they/them', + currentCohortId: mockCohortId, + currentExperimentId: mockExperimentId, + currentStageId: mockStageId, + timestamps: { + accountCreated: {seconds: 0, nanoseconds: 0}, + lastLogin: {seconds: 0, nanoseconds: 0}, + }, + agentConfig: null, + prolificId: null, + transferCohortId: null, + variableMap: {}, + }; + + const mockParticipant2: ParticipantProfileExtended = { + ...mockParticipant, + id: 'participant-2', + privateId: 'participant-private-2', + publicId: 'participant-public-2', + name: 'Test Participant 2', + avatar: '🐱', + }; + + const mockParticipant3: ParticipantProfileExtended = { + ...mockParticipant, + id: 'participant-3', + privateId: 'participant-private-3', + publicId: 'participant-public-3', + name: 'Test Participant 3', + avatar: '🐰', + }; + + const mockExperiment = { + id: mockExperimentId, + versionId: 'v1', + metadata: { + name: 'Test Experiment', + publicName: 'Test', + description: '', + tags: [], + date: {seconds: 0, nanoseconds: 0}, + }, + permissions: {experimenters: [], viewers: []}, + stageIds: [mockSurveyStageId, mockStageId], + prolific: { + enableProlificIntegration: false, + defaultRedirectCode: null, + completionCodeMap: {}, + }, + attentionCheckConfig: { + numFailed: 0, + }, + variableMap: {}, + variableConfigs: [], + }; + + const mockCohort = { + id: mockCohortId, + name: 'Test Cohort', + metadata: { + numberOfParticipants: 3, + publicName: null, + description: null, + date: {seconds: 0, nanoseconds: 0}, + }, + participantConfig: { + allowedParticipantIds: [], + disabledParticipantIds: [], + }, + variableMap: {}, + }; + + const mockSurveyStage = { + id: mockSurveyStageId, + kind: StageKind.SURVEY, + name: 'Survey', + descriptions: {primaryText: '', infoText: '', helpText: ''}, + questions: [{id: 'q1', kind: SurveyQuestionKind.SCALE}], + }; + + beforeEach(() => { + jest.clearAllMocks(); + + (firestoreUtils.getFirestoreExperiment as jest.Mock).mockResolvedValue( + mockExperiment, + ); + (firestoreUtils.getFirestoreCohort as jest.Mock).mockResolvedValue( + mockCohort, + ); + ( + firestoreUtils.getFirestoreActiveParticipants as jest.Mock + ).mockResolvedValue([ + mockParticipant, + mockParticipant2, + mockParticipant3, + ]); + (firestoreUtils.getFirestoreParticipant as jest.Mock).mockImplementation( + async (experimentId: string, participantId: string) => { + if (participantId === 'participant-private-1') return mockParticipant; + if (participantId === 'participant-private-2') + return mockParticipant2; + if (participantId === 'participant-private-3') + return mockParticipant3; + return undefined; + }, + ); + (firestoreUtils.getFirestoreStage as jest.Mock).mockImplementation( + async (experimentId: string, stageId: string) => { + if (stageId === mockSurveyStageId) return mockSurveyStage; + return undefined; + }, + ); + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([]); + }); + + it('should return all items when no conditions are present', async () => { + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [ + {type: PromptItemType.TEXT, text: 'Item 1'}, + {type: PromptItemType.TEXT, text: 'Item 2'}, + {type: PromptItemType.TEXT, text: 'Item 3'}, + ], + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + expect(result.filteredPromptItems).toHaveLength(3); + }); + + it('should filter out items with failing conditions', async () => { + // Mock survey answers where participant answered 3 + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([ + { + participantPublicId: 'participant-public-1', + answer: { + kind: StageKind.SURVEY, + answerMap: { + q1: {kind: SurveyQuestionKind.SCALE, value: 3}, + }, + }, + }, + ]); + + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [ + {type: PromptItemType.TEXT, text: 'Always shown'}, + { + type: PromptItemType.TEXT, + text: 'Show if q1 > 5', + condition: createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.GREATER_THAN, + 5, + ), + }, + { + type: PromptItemType.TEXT, + text: 'Show if q1 < 5', + condition: createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.LESS_THAN, + 5, + ), + }, + ], + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + // Should have 2 items: "Always shown" and "Show if q1 < 5" + expect(result.filteredPromptItems).toHaveLength(2); + expect( + result.filteredPromptItems.map((item) => + item.type === PromptItemType.TEXT ? item.text : null, + ), + ).toEqual(['Always shown', 'Show if q1 < 5']); + }); + + it('should filter nested GROUP items recursively', async () => { + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([ + { + participantPublicId: 'participant-public-1', + answer: { + kind: StageKind.SURVEY, + answerMap: { + q1: {kind: SurveyQuestionKind.SCALE, value: 7}, + }, + }, + }, + ]); + + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [ + { + type: PromptItemType.GROUP, + items: [ + {type: PromptItemType.TEXT, text: 'Group item 1'}, + { + type: PromptItemType.TEXT, + text: 'Group item 2 - conditional', + condition: createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.LESS_THAN, + 5, // Fails: 7 is not < 5 + ), + }, + {type: PromptItemType.TEXT, text: 'Group item 3'}, + ], + }, + ], + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + // GROUP should be present with filtered children + expect(result.filteredPromptItems).toHaveLength(1); + const group = result.filteredPromptItems[0] as PromptItemGroup; + expect(group.type).toBe(PromptItemType.GROUP); + expect(group.items).toHaveLength(2); + expect(group.items.map((item) => (item as TextPromptItem).text)).toEqual([ + 'Group item 1', + 'Group item 3', + ]); + }); + + it('should filter out entire GROUP when GROUP condition fails', async () => { + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([ + { + participantPublicId: 'participant-public-1', + answer: { + kind: StageKind.SURVEY, + answerMap: { + q1: {kind: SurveyQuestionKind.SCALE, value: 3}, + }, + }, + }, + ]); + + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [ + {type: PromptItemType.TEXT, text: 'Before group'}, + { + type: PromptItemType.GROUP, + condition: createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.GREATER_THAN, + 5, // Fails: 3 is not > 5 + ), + items: [ + {type: PromptItemType.TEXT, text: 'Group item 1'}, + {type: PromptItemType.TEXT, text: 'Group item 2'}, + ], + }, + {type: PromptItemType.TEXT, text: 'After group'}, + ], + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + // GROUP should be filtered out entirely + expect(result.filteredPromptItems).toHaveLength(2); + expect( + result.filteredPromptItems.map((item) => + item.type === PromptItemType.TEXT ? item.text : 'GROUP', + ), + ).toEqual(['Before group', 'After group']); + }); + + it('should support condition groups with AND operator', async () => { + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([ + { + participantPublicId: 'participant-public-1', + answer: { + kind: StageKind.SURVEY, + answerMap: { + q1: {kind: SurveyQuestionKind.SCALE, value: 7}, + }, + }, + }, + ]); + + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [ + { + type: PromptItemType.TEXT, + text: 'Show if 5 < q1 < 10', + condition: createConditionGroup(ConditionOperator.AND, [ + createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.GREATER_THAN, + 5, + ), + createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.LESS_THAN, + 10, + ), + ]), + }, + { + type: PromptItemType.TEXT, + text: 'Show if q1 > 10 AND q1 < 5 (impossible)', + condition: createConditionGroup(ConditionOperator.AND, [ + createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.GREATER_THAN, + 10, + ), + createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.LESS_THAN, + 5, + ), + ]), + }, + ], + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + expect(result.filteredPromptItems).toHaveLength(1); + expect((result.filteredPromptItems[0] as TextPromptItem).text).toBe( + 'Show if 5 < q1 < 10', + ); + }); + + it('should support aggregation conditions with ANY operator in group chat', async () => { + // Three participants: one answered 8, one answered 3, one answered 6 + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([ + { + participantPublicId: 'participant-public-1', + answer: { + kind: StageKind.SURVEY, + answerMap: {q1: {kind: SurveyQuestionKind.SCALE, value: 8}}, + }, + }, + { + participantPublicId: 'participant-public-2', + answer: { + kind: StageKind.SURVEY, + answerMap: {q1: {kind: SurveyQuestionKind.SCALE, value: 3}}, + }, + }, + { + participantPublicId: 'participant-public-3', + answer: { + kind: StageKind.SURVEY, + answerMap: {q1: {kind: SurveyQuestionKind.SCALE, value: 6}}, + }, + }, + ]); + + const promptConfig: BasePromptConfig = { + type: StageKind.CHAT, // Group chat + prompt: [ + { + type: PromptItemType.TEXT, + text: 'Show if ANY participant answered > 7', + condition: createAggregationCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + AggregationOperator.ANY, + ComparisonOperator.GREATER_THAN, + 7, + ), + }, + { + type: PromptItemType.TEXT, + text: 'Show if ALL participants answered > 7', + condition: createAggregationCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + AggregationOperator.ALL, + ComparisonOperator.GREATER_THAN, + 7, + ), + }, + ], + }; + + const mediator: MediatorProfileExtended = { + id: 'mediator-1', + privateId: 'mediator-private-1', + publicId: 'mediator-public-1', + type: UserType.MEDIATOR, + name: 'Test Mediator', + avatar: '🤖', + pronouns: 'they/them', + currentCohortId: mockCohortId, + currentExperimentId: mockExperimentId, + currentStageId: mockStageId, + timestamps: { + accountCreated: {seconds: 0, nanoseconds: 0}, + lastLogin: {seconds: 0, nanoseconds: 0}, + }, + agentConfig: { + agentId: 'test-agent', + model: 'test-model', + apiKey: 'test-key', + promptContext: 'test context', + }, + prolificId: null, + transferCohortId: null, + variableMap: {}, + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mediator, + promptConfig, + ); + + // ANY > 7 passes (participant 1 has 8), ALL > 7 fails + expect(result.filteredPromptItems).toHaveLength(1); + expect((result.filteredPromptItems[0] as TextPromptItem).text).toBe( + 'Show if ANY participant answered > 7', + ); + }); + + it('should support aggregation conditions with COUNT operator', async () => { + // Three participants: values 8, 3, 6 + ( + firestoreUtils.getFirestoreAnswersForStage as jest.Mock + ).mockResolvedValue([ + { + participantPublicId: 'participant-public-1', + answer: { + kind: StageKind.SURVEY, + answerMap: {q1: {kind: SurveyQuestionKind.SCALE, value: 8}}, + }, + }, + { + participantPublicId: 'participant-public-2', + answer: { + kind: StageKind.SURVEY, + answerMap: {q1: {kind: SurveyQuestionKind.SCALE, value: 3}}, + }, + }, + { + participantPublicId: 'participant-public-3', + answer: { + kind: StageKind.SURVEY, + answerMap: {q1: {kind: SurveyQuestionKind.SCALE, value: 6}}, + }, + }, + ]); + + const promptConfig: BasePromptConfig = { + type: StageKind.CHAT, + prompt: [ + { + type: PromptItemType.TEXT, + text: 'Show if COUNT of answers > 5 is >= 2', + condition: createAggregationCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + AggregationOperator.COUNT, + ComparisonOperator.GREATER_THAN_OR_EQUAL, + 2, + {operator: ComparisonOperator.GREATER_THAN, value: 5}, // filterComparison + ), + }, + { + type: PromptItemType.TEXT, + text: 'Show if COUNT of answers > 5 is >= 3', + condition: createAggregationCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + AggregationOperator.COUNT, + ComparisonOperator.GREATER_THAN_OR_EQUAL, + 3, + {operator: ComparisonOperator.GREATER_THAN, value: 5}, + ), + }, + ], + }; + + const mediator: MediatorProfileExtended = { + id: 'mediator-1', + privateId: 'mediator-private-1', + publicId: 'mediator-public-1', + type: UserType.MEDIATOR, + name: 'Test Mediator', + avatar: '🤖', + pronouns: 'they/them', + currentCohortId: mockCohortId, + currentExperimentId: mockExperimentId, + currentStageId: mockStageId, + timestamps: { + accountCreated: {seconds: 0, nanoseconds: 0}, + lastLogin: {seconds: 0, nanoseconds: 0}, + }, + agentConfig: { + agentId: 'test-agent', + model: 'test-model', + apiKey: 'test-key', + promptContext: 'test context', + }, + prolificId: null, + transferCohortId: null, + variableMap: {}, + }; + + const result = await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mediator, + promptConfig, + ); + + // Values > 5: [8, 6] = 2 items + // COUNT >= 2: passes + // COUNT >= 3: fails + expect(result.filteredPromptItems).toHaveLength(1); + expect((result.filteredPromptItems[0] as TextPromptItem).text).toBe( + 'Show if COUNT of answers > 5 is >= 2', + ); + }); + + it('should not fetch condition deps when no conditions present', async () => { + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [{type: PromptItemType.TEXT, text: 'No conditions'}], + }; + + await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + // Should not fetch stage config or answers for conditions + expect(firestoreUtils.getFirestoreStage).not.toHaveBeenCalled(); + expect(firestoreUtils.getFirestoreAnswersForStage).not.toHaveBeenCalled(); + }); + + it('should fetch condition deps only for stages referenced in conditions', async () => { + const anotherSurveyStageId = 'another-survey-stage'; + + const promptConfig: BasePromptConfig = { + type: StageKind.PRIVATE_CHAT, + prompt: [ + { + type: PromptItemType.TEXT, + text: 'Conditional', + condition: createComparisonCondition( + {stageId: mockSurveyStageId, questionId: 'q1'}, + ComparisonOperator.GREATER_THAN, + 5, + ), + }, + ], + }; + + await getFirestoreDataForStructuredPrompt( + mockExperimentId, + mockCohortId, + mockStageId, + mockParticipant, + promptConfig, + ); + + // Should only fetch the survey stage referenced in conditions + expect(firestoreUtils.getFirestoreStage).toHaveBeenCalledTimes(1); + expect(firestoreUtils.getFirestoreStage).toHaveBeenCalledWith( + mockExperimentId, + mockSurveyStageId, + ); + // Should not fetch another-survey-stage + expect(firestoreUtils.getFirestoreStage).not.toHaveBeenCalledWith( + mockExperimentId, + anotherSurveyStageId, + ); + }); + }); }); diff --git a/functions/src/structured_prompt.utils.ts b/functions/src/structured_prompt.utils.ts index 16a333cda..b8bd16aa9 100644 --- a/functions/src/structured_prompt.utils.ts +++ b/functions/src/structured_prompt.utils.ts @@ -19,8 +19,9 @@ import { StageContextPromptItem, StageKind, UserType, - extractConditionDependencies, - evaluateConditionWithStageAnswers, + extractMultipleConditionDependencies, + filterByCondition, + hasAggregationConditions, getAllPrecedingStageIds, getNameFromPublicId, getVariableContext, @@ -97,6 +98,8 @@ export async function getFirestoreDataForStructuredPrompt( // participants whose answers should be used in prompt participants: ParticipantProfileExtended[]; data: Record; + // prompt items filtered by conditions + filteredPromptItems: PromptItem[]; }> { const data: Record = {}; @@ -133,29 +136,80 @@ export async function getFirestoreDataForStructuredPrompt( answerParticipants = activeParticipants; } - for (const item of promptConfig.prompt) { + // Collect all conditions from prompt items and fetch required stage answers + const allConditions = collectPromptItemConditions(promptConfig.prompt); + await fetchConditionStageAnswers( + experimentId, + cohortId, + allConditions, + answerParticipants, + data, + ); + + // Determine condition evaluation context based on stage type + const isGroupChat = promptConfig.type === StageKind.CHAT; + + // For private chat with single participant context, use that participant + // For group chat or mediator context, determine target participant + let conditionParticipant: ParticipantProfileExtended | undefined; + if (userProfile.type === UserType.PARTICIPANT) { + conditionParticipant = answerParticipants[0]; + } else if (answerParticipants.length === 1) { + // Mediator with single participant context (e.g., private chat) + conditionParticipant = answerParticipants[0]; + } + + // Build stage answers for condition evaluation + const singleParticipantAnswers = conditionParticipant + ? buildStageAnswersForParticipant(data, conditionParticipant.publicId) + : {}; + + // For aggregation conditions in group chat, collect all participant answers + const needsAggregation = + isGroupChat && allConditions.some((c) => hasAggregationConditions(c)); + const allParticipantAnswers = needsAggregation + ? buildAllParticipantAnswers(data, answerParticipants) + : undefined; + + // Filter prompt items by conditions before fetching additional data + const visiblePromptItems = filterPromptItemsRecursively( + promptConfig.prompt, + singleParticipantAnswers, + conditionParticipant?.publicId, + allParticipantAnswers, + ); + + // Fetch data only for visible prompt items + for (const item of visiblePromptItems) { await addFirestoreDataForPromptItem( experiment, cohortId, currentStageId, - promptConfig.type, item, activeParticipants, answerParticipants, data, ); } - return {experiment, cohort, participants: answerParticipants, data}; + return { + experiment, + cohort, + participants: answerParticipants, + data, + filteredPromptItems: visiblePromptItems, + }; } /** Populates data object with Firestore documents needed for given single * prompt item. + * + * Note: This function assumes prompt items have already been filtered by + * conditions via filterPromptItemsRecursively() in getFirestoreDataForStructuredPrompt(). */ export async function addFirestoreDataForPromptItem( experiment: Experiment, cohortId: string, currentStageId: string, - stageKind: StageKind, promptItem: PromptItem, // All active participants in cohort activeParticipants: ParticipantProfileExtended[], @@ -163,26 +217,6 @@ export async function addFirestoreDataForPromptItem( answerParticipants: ParticipantProfileExtended[], data: Record = {}, ) { - // Check condition if present - // Conditions are only supported for private chat contexts - if (promptItem.condition && stageKind === StageKind.PRIVATE_CHAT) { - // Lazily fetch any missing stage data needed for condition evaluation - await fetchConditionDependencies( - experiment.id, - cohortId, - promptItem.condition, - answerParticipants, - data, - ); - - // Evaluate condition - skip fetching remaining data if condition not met - if ( - !shouldIncludePromptItem(promptItem, stageKind, answerParticipants, data) - ) { - return; - } - } - // Get profile set ID based on stage ID // (Temporary workaround before profile sets are refactored) const getProfileSetId = (stageId: string) => { @@ -207,7 +241,6 @@ export async function addFirestoreDataForPromptItem( experiment, cohortId, currentStageId, - stageKind, {...promptItem, stageId}, activeParticipants, answerParticipants, @@ -281,7 +314,6 @@ export async function addFirestoreDataForPromptItem( experiment, cohortId, currentStageId, - stageKind, item, activeParticipants, answerParticipants, @@ -329,8 +361,9 @@ export async function getPromptFromConfig( contextParticipantIds, ); + // Use filtered prompt items (conditions already evaluated) const promptText = await processPromptItems( - promptConfig.prompt, + promptData.filteredPromptItems, cohortId, stageId, promptConfig.type, // Pass stageKind to distinguish privateChat from groupChat. @@ -420,17 +453,23 @@ function getProfileInfoForPrompt( } /** - * Lazily fetch stage data needed for condition evaluation. + * Fetch stage data needed for condition evaluation. * Only fetches data for stages not already present in the data object. + * + * @param conditions - All conditions to evaluate (collected from prompt items) */ -async function fetchConditionDependencies( +async function fetchConditionStageAnswers( experimentId: string, cohortId: string, - condition: Condition, + conditions: Condition[], answerParticipants: ParticipantProfileExtended[], data: Record, ): Promise { - const dependencies = extractConditionDependencies(condition); + if (conditions.length === 0) { + return; + } + + const dependencies = extractMultipleConditionDependencies(conditions); const requiredStageIds = [...new Set(dependencies.map((dep) => dep.stageId))]; // Find stages not already in data @@ -477,34 +516,92 @@ function buildStageAnswersForParticipant( } /** - * Evaluate a prompt item's condition for a single participant. - * Returns true if the condition is met (or if there's no condition). - * Only works for private chat contexts with a single participant. + * Collect all conditions from prompt items recursively (including nested groups). */ -function shouldIncludePromptItem( - promptItem: PromptItem, - stageKind: StageKind, - participants: ParticipantProfileExtended[], +function collectPromptItemConditions(items: PromptItem[]): Condition[] { + const conditions: Condition[] = []; + + for (const item of items) { + if (item.condition) { + conditions.push(item.condition); + } + if (item.type === PromptItemType.GROUP) { + conditions.push( + ...collectPromptItemConditions((item as PromptItemGroup).items), + ); + } + } + + return conditions; +} + +/** + * Build stage answers map for ALL participants. + * Used for aggregation conditions that need to evaluate across multiple participants. + */ +function buildAllParticipantAnswers( stageContextData: Record, -): boolean { - if ( - !promptItem.condition || - stageKind !== StageKind.PRIVATE_CHAT || - participants.length !== 1 - ) { - return true; + participants: ParticipantProfileExtended[], +): Record> { + const allAnswers: Record> = {}; + + for (const participant of participants) { + allAnswers[participant.publicId] = buildStageAnswersForParticipant( + stageContextData, + participant.publicId, + ); } - const stageAnswers = buildStageAnswersForParticipant( - stageContextData, - participants[0].publicId, + + return allAnswers; +} + +/** + * Filter prompt items by conditions, recursively handling nested groups. + * + * @param items - The prompt items to filter + * @param stageAnswers - Stage answers for a single participant (used for comparison conditions) + * @param targetParticipantId - For SurveyPerParticipant stages, which participant's answers to use + * @param allParticipantAnswers - For aggregation conditions, all participants' answers + */ +function filterPromptItemsRecursively( + items: PromptItem[], + stageAnswers: Record, + targetParticipantId?: string, + allParticipantAnswers?: Record< + string, + Record + >, +): PromptItem[] { + const visibleItems = filterByCondition( + items, + stageAnswers, + targetParticipantId, + allParticipantAnswers, ); - return evaluateConditionWithStageAnswers(promptItem.condition, stageAnswers); + + return visibleItems.map((item) => { + if (item.type === PromptItemType.GROUP) { + return { + ...item, + items: filterPromptItemsRecursively( + (item as PromptItemGroup).items, + stageAnswers, + targetParticipantId, + allParticipantAnswers, + ), + }; + } + return item; + }); } /** * Process prompt items recursively and return the assembled prompt text. * - * @param promptItems - The list of prompt items to process. + * Note: This function expects promptItems to have already been filtered by + * conditions via filterPromptItemsRecursively() in getFirestoreDataForStructuredPrompt(). + * + * @param promptItems - The list of prompt items to process (already filtered by conditions). * @param cohortId - The cohort ID (used as shuffle seed for GROUP items with cohort-based shuffling). * @param stageId - The current stage ID. Used for: * - Determining preceding stages when STAGE_CONTEXT has empty stageId @@ -518,8 +615,8 @@ function shouldIncludePromptItem( * @param promptData - Pre-fetched data from Firestore: * - experiment: The experiment config (also used for experiment-based shuffle seeding) * - cohort: The cohort config (for cohort-level variables) - * - participants: Participants whose answers are used for STAGE_CONTEXT rendering, - * condition evaluation, and participant-based shuffle seeding + * - participants: Participants whose answers are used for STAGE_CONTEXT rendering + * and participant-based shuffle seeding * - data: Stage context data keyed by stage ID * @param userProfile - The agent's profile (participant or mediator). Used for: * - Determining variable resolution strategy (agent participants use their own variables) @@ -568,19 +665,9 @@ async function processPromptItems( participantForVariables, ); + // Note: promptItems have already been filtered by conditions in + // getFirestoreDataForStructuredPrompt() via filterPromptItemsRecursively() for (const promptItem of promptItems) { - // Check condition if present (only for private chat contexts) - if ( - !shouldIncludePromptItem( - promptItem, - stageKind, - promptData.participants, - promptData.data, - ) - ) { - continue; - } - switch (promptItem.type) { case PromptItemType.TEXT: // Resolve template variables in text prompt items diff --git a/utils/src/stages/survey_stage.ts b/utils/src/stages/survey_stage.ts index 74f19e631..06c3b6925 100644 --- a/utils/src/stages/survey_stage.ts +++ b/utils/src/stages/survey_stage.ts @@ -5,7 +5,10 @@ import { extractConditionDependencies, extractMultipleConditionDependencies, } from '../utils/condition'; -import {getConditionDependencyValuesWithCurrentStage} from '../utils/condition.utils'; +import { + buildAllValuesWithCurrentStage, + getConditionDependencyValuesWithCurrentStage, +} from '../utils/condition.utils'; import { BaseStageConfig, BaseStageParticipantAnswer, @@ -318,27 +321,42 @@ export function getVisibleSurveyQuestions( currentStageId: string, currentStageAnswers: Record, allStageAnswers?: Record, // Map of stageId to answer data - targetParticipantId?: string, // For survey-per-participant: which participant is being evaluated + currentParticipantId?: string, // The current participant taking the survey + allParticipantAnswers?: Record< + string, + Record + >, // All participants' answers ): SurveyQuestion[] { // Extract all dependencies from all question conditions - const allDependencies = extractMultipleConditionDependencies( - questions.map((q) => q.condition), - ); + const allConditions = questions + .map((q) => q.condition) + .filter((c): c is Condition => c !== undefined); + const allDependencies = extractMultipleConditionDependencies(allConditions); - // Get the actual values for all condition targets + // Get target values for comparison conditions (single participant's answers) const targetValues = getConditionDependencyValuesWithCurrentStage( allDependencies, currentStageId, currentStageAnswers, allStageAnswers, - targetParticipantId, ); + // Build aggregated values for aggregation conditions (all participants' answers) + const allValues = allParticipantAnswers + ? buildAllValuesWithCurrentStage( + allDependencies, + currentStageId, + currentStageAnswers, + currentParticipantId, + allParticipantAnswers, + ) + : undefined; + return questions.filter((question) => { if (!question.condition) { return true; // No condition means always show } - return evaluateCondition(question.condition, targetValues); + return evaluateCondition(question.condition, targetValues, allValues); }); } @@ -348,7 +366,11 @@ export function isQuestionVisible( currentStageId: string, currentStageAnswers: Record, allStageAnswers?: Record, - targetParticipantId?: string, // For survey-per-participant: which participant is being evaluated + currentParticipantId?: string, // The current participant taking the survey + allParticipantAnswers?: Record< + string, + Record + >, // All participants' answers ): boolean { if (!question.condition) { return true; // No condition means always show @@ -357,16 +379,26 @@ export function isQuestionVisible( // Extract only this question's dependencies const dependencies = extractConditionDependencies(question.condition); - // Get the actual values for this question's condition targets + // Get target values for comparison conditions (single participant's answers) const targetValues = getConditionDependencyValuesWithCurrentStage( dependencies, currentStageId, currentStageAnswers, allStageAnswers, - targetParticipantId, ); - return evaluateCondition(question.condition, targetValues); + // Build aggregated values for aggregation conditions (all participants' answers) + const allValues = allParticipantAnswers + ? buildAllValuesWithCurrentStage( + dependencies, + currentStageId, + currentStageAnswers, + currentParticipantId, + allParticipantAnswers, + ) + : undefined; + + return evaluateCondition(question.condition, targetValues, allValues); } /** Extract the value from a survey answer */ diff --git a/utils/src/utils/condition.test.ts b/utils/src/utils/condition.test.ts index 75854642b..68e2e6680 100644 --- a/utils/src/utils/condition.test.ts +++ b/utils/src/utils/condition.test.ts @@ -1,11 +1,14 @@ import { ConditionOperator, ComparisonOperator, + AggregationOperator, ConditionGroup, ComparisonCondition, + AggregationCondition, ConditionTargetReference, createConditionGroup, createComparisonCondition, + createAggregationCondition, evaluateCondition, getComparisonOperatorLabel, getConditionOperatorLabel, @@ -13,6 +16,7 @@ import { extractMultipleConditionDependencies, getConditionTargetKey, parseConditionTargetKey, + hasAggregationConditions, } from './condition'; // Mock generateId for predictable test results @@ -127,7 +131,47 @@ describe('condition utilities', () => { }); }); + describe('createAggregationCondition', () => { + test('creates aggregation condition with default values', () => { + const target: ConditionTargetReference = { + stageId: 'stage1', + questionId: 'q1', + }; + const condition = createAggregationCondition(target); + expect(condition).toEqual({ + id: 'test-id', + type: 'aggregation', + target, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.EQUALS, + value: '', + }); + }); + + test('creates aggregation condition with specified values', () => { + const target: ConditionTargetReference = { + stageId: 'stage2', + questionId: 'q2', + }; + const condition = createAggregationCondition( + target, + AggregationOperator.COUNT, + ComparisonOperator.GREATER_THAN_OR_EQUAL, + 3, + ); + expect(condition).toEqual({ + id: 'test-id', + type: 'aggregation', + target, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.GREATER_THAN_OR_EQUAL, + value: 3, + }); + }); + }); + describe('evaluateCondition', () => { + // Flat target values for comparison tests const targetValues = { 'stage1::q1': 'apple', 'stage1::q2': 5, @@ -446,6 +490,442 @@ describe('condition utilities', () => { expect(evaluateCondition(nestedGroup, targetValues)).toBe(true); }); }); + + describe('aggregation conditions', () => { + // Aggregated values: arrays of values for each target key + // Represents 4 participants with different values + const allValues: Record = { + 'stage1::q1': ['yes', 'yes', 'no', 'yes'], // 3 yes, 1 no + 'stage1::q2': [5, 10, 15, 20], // sum=50, avg=12.5 + 'stage1::q3': ['a', 'b', 'c'], // only 3 values (p4 didn't answer) + }; + + // Empty values for testing empty scenarios + const emptyValues: Record = {}; + + describe('ANY aggregator', () => { + test('returns true if any value matches', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.EQUALS, + value: 'yes', + }; + // 3 out of 4 values are 'yes' + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('returns false if no value matches', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.EQUALS, + value: 'maybe', + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('returns false for empty values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.EQUALS, + value: 'test', + }; + expect(evaluateCondition(condition, {}, emptyValues)).toBe(false); + }); + + test('works with numeric comparisons', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.GREATER_THAN, + value: 15, + }; + // value 20 is > 15 + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + + condition.value = 25; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + }); + + describe('ALL aggregator', () => { + test('returns true if all values match', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.ALL, + operator: ComparisonOperator.GREATER_THAN, + value: 0, + }; + // All values (5, 10, 15, 20) are > 0 + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('returns false if any value does not match', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.ALL, + operator: ComparisonOperator.EQUALS, + value: 'yes', + }; + // one value is 'no', so not all are 'yes' + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('returns false for empty values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.ALL, + operator: ComparisonOperator.EQUALS, + value: 'test', + }; + expect(evaluateCondition(condition, {}, emptyValues)).toBe(false); + }); + }); + + describe('NONE aggregator', () => { + test('returns true if no value matches', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.NONE, + operator: ComparisonOperator.EQUALS, + value: 'maybe', + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('returns false if any value matches', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.NONE, + operator: ComparisonOperator.EQUALS, + value: 'yes', + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('returns false for empty values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.NONE, + operator: ComparisonOperator.EQUALS, + value: 'test', + }; + expect(evaluateCondition(condition, {}, emptyValues)).toBe(false); + }); + }); + + describe('COUNT aggregator', () => { + test('counts all non-null values without filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.EQUALS, + value: 4, + }; + // All 4 values for q1 + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + + condition.value = 3; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('counts filtered values with filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.EQUALS, + value: 3, + filterComparison: { + operator: ComparisonOperator.EQUALS, + value: 'yes', + }, + }; + // 3 values are 'yes' + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('count with GREATER_THAN_OR_EQUAL comparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.GREATER_THAN_OR_EQUAL, + value: 3, + filterComparison: { + operator: ComparisonOperator.EQUALS, + value: 'yes', + }, + }; + // 3 values are 'yes', >= 3 + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + + condition.value = 4; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('count with numeric filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.EQUALS, + value: 2, + filterComparison: { + operator: ComparisonOperator.GREATER_THAN, + value: 10, + }, + }; + // Values > 10: 15, 20 = 2 values + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('returns false for empty values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.EQUALS, + value: 0, + }; + // No values to count + expect(evaluateCondition(condition, {}, emptyValues)).toBe(false); + }); + }); + + describe('SUM aggregator', () => { + test('sums all values without filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.SUM, + operator: ComparisonOperator.EQUALS, + value: 50, // 5 + 10 + 15 + 20 = 50 + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('sums filtered values with filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.SUM, + operator: ComparisonOperator.EQUALS, + value: 35, // 15 + 20 = 35 (values > 10) + filterComparison: { + operator: ComparisonOperator.GREATER_THAN, + value: 10, + }, + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('sum with GREATER_THAN comparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.SUM, + operator: ComparisonOperator.GREATER_THAN, + value: 40, + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + + condition.value = 50; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('returns false for empty values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.SUM, + operator: ComparisonOperator.EQUALS, + value: 0, + }; + expect(evaluateCondition(condition, {}, emptyValues)).toBe(false); + }); + }); + + describe('AVERAGE aggregator', () => { + test('averages all values without filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.AVERAGE, + operator: ComparisonOperator.EQUALS, + value: 12.5, // (5 + 10 + 15 + 20) / 4 = 12.5 + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('averages filtered values with filterComparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.AVERAGE, + operator: ComparisonOperator.EQUALS, + value: 17.5, // (15 + 20) / 2 = 17.5 (values > 10) + filterComparison: { + operator: ComparisonOperator.GREATER_THAN, + value: 10, + }, + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + }); + + test('average with GREATER_THAN_OR_EQUAL comparison', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.AVERAGE, + operator: ComparisonOperator.GREATER_THAN_OR_EQUAL, + value: 12.5, + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(true); + + condition.value = 13; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + + test('returns false for empty values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.AVERAGE, + operator: ComparisonOperator.EQUALS, + value: 0, + }; + expect(evaluateCondition(condition, {}, emptyValues)).toBe(false); + }); + + test('returns false when filter leaves no values', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.AVERAGE, + operator: ComparisonOperator.EQUALS, + value: 0, + filterComparison: { + operator: ComparisonOperator.GREATER_THAN, + value: 100, // No values > 100 + }, + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + }); + + describe('aggregation with missing target', () => { + test('returns false for non-existent target', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'nonexistent', questionId: 'q1'}, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.EQUALS, + value: 'test', + }; + expect(evaluateCondition(condition, {}, allValues)).toBe(false); + }); + }); + + describe('aggregation in condition groups', () => { + test('combines aggregation conditions', () => { + const group: ConditionGroup = { + id: '1', + type: 'group', + operator: ConditionOperator.AND, + conditions: [ + { + id: '2', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.GREATER_THAN_OR_EQUAL, + value: 3, + filterComparison: { + operator: ComparisonOperator.EQUALS, + value: 'yes', + }, + }, + { + id: '3', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.AVERAGE, + operator: ComparisonOperator.GREATER_THAN, + value: 10, + }, + ], + }; + // 3 'yes' values >= 3, and average 12.5 > 10 + expect(evaluateCondition(group, {}, allValues)).toBe(true); + }); + + test('OR group with aggregation conditions', () => { + const group: ConditionGroup = { + id: '1', + type: 'group', + operator: ConditionOperator.OR, + conditions: [ + { + id: '2', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.ALL, + operator: ComparisonOperator.EQUALS, + value: 'yes', // false - not all are 'yes' + }, + { + id: '3', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q2'}, + aggregator: AggregationOperator.SUM, + operator: ComparisonOperator.EQUALS, + value: 50, // true - sum is 50 + }, + ], + }; + expect(evaluateCondition(group, {}, allValues)).toBe(true); + }); + }); + }); }); describe('getComparisonOperatorLabel', () => { @@ -514,6 +994,20 @@ describe('condition utilities', () => { ]); }); + test('extracts single dependency from aggregation condition', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.GREATER_THAN, + value: 3, + }; + expect(extractConditionDependencies(condition)).toEqual([ + {stageId: 'stage1', questionId: 'q1'}, + ]); + }); + test('extracts dependencies from condition group', () => { const group: ConditionGroup = { id: '1', @@ -681,6 +1175,129 @@ describe('condition utilities', () => { }); }); + describe('hasAggregationConditions', () => { + test('returns false for undefined condition', () => { + expect(hasAggregationConditions(undefined)).toBe(false); + }); + + test('returns false for comparison condition', () => { + const condition: ComparisonCondition = { + id: '1', + type: 'comparison', + target: {stageId: 'stage1', questionId: 'q1'}, + operator: ComparisonOperator.EQUALS, + value: 'test', + }; + expect(hasAggregationConditions(condition)).toBe(false); + }); + + test('returns true for aggregation condition', () => { + const condition: AggregationCondition = { + id: '1', + type: 'aggregation', + target: {stageId: 'stage1', questionId: 'q1'}, + aggregator: AggregationOperator.COUNT, + operator: ComparisonOperator.GREATER_THAN, + value: 3, + }; + expect(hasAggregationConditions(condition)).toBe(true); + }); + + test('returns false for group with only comparison conditions', () => { + const group: ConditionGroup = { + id: '1', + type: 'group', + operator: ConditionOperator.AND, + conditions: [ + { + id: '2', + type: 'comparison', + target: {stageId: 'stage1', questionId: 'q1'}, + operator: ComparisonOperator.EQUALS, + value: 'test', + }, + { + id: '3', + type: 'comparison', + target: {stageId: 'stage2', questionId: 'q2'}, + operator: ComparisonOperator.EQUALS, + value: 'test', + }, + ], + }; + expect(hasAggregationConditions(group)).toBe(false); + }); + + test('returns true for group containing aggregation condition', () => { + const group: ConditionGroup = { + id: '1', + type: 'group', + operator: ConditionOperator.AND, + conditions: [ + { + id: '2', + type: 'comparison', + target: {stageId: 'stage1', questionId: 'q1'}, + operator: ComparisonOperator.EQUALS, + value: 'test', + }, + { + id: '3', + type: 'aggregation', + target: {stageId: 'stage2', questionId: 'q2'}, + aggregator: AggregationOperator.ANY, + operator: ComparisonOperator.EQUALS, + value: 'test', + }, + ], + }; + expect(hasAggregationConditions(group)).toBe(true); + }); + + test('returns true for nested group containing aggregation condition', () => { + const nestedGroup: ConditionGroup = { + id: '1', + type: 'group', + operator: ConditionOperator.AND, + conditions: [ + { + id: '2', + type: 'group', + operator: ConditionOperator.OR, + conditions: [ + { + id: '3', + type: 'comparison', + target: {stageId: 'stage1', questionId: 'q1'}, + operator: ComparisonOperator.EQUALS, + value: 'test', + }, + { + id: '4', + type: 'aggregation', + target: {stageId: 'stage2', questionId: 'q2'}, + aggregator: AggregationOperator.SUM, + operator: ComparisonOperator.GREATER_THAN, + value: 100, + }, + ], + }, + ], + }; + expect(hasAggregationConditions(nestedGroup)).toBe(true); + }); + + test('returns false for empty group', () => { + const group: ConditionGroup = { + id: '1', + type: 'group', + operator: ConditionOperator.AND, + conditions: [], + }; + expect(hasAggregationConditions(group)).toBe(false); + }); + }); + describe('getConditionTargetKey', () => { test('builds key from stage and question IDs', () => { const target: ConditionTargetReference = { diff --git a/utils/src/utils/condition.ts b/utils/src/utils/condition.ts index 150a3a0f2..581764744 100644 --- a/utils/src/utils/condition.ts +++ b/utils/src/utils/condition.ts @@ -18,9 +18,19 @@ export enum ComparisonOperator { NOT_CONTAINS = 'not_contains', } +/** Aggregation operators for evaluating conditions across multiple values */ +export enum AggregationOperator { + ANY = 'any', // True if any value passes the comparison + ALL = 'all', // True if all values pass the comparison + NONE = 'none', // True if no value passes the comparison + COUNT = 'count', // Count of values (optionally filtered), comparison applied to count + SUM = 'sum', // Sum of numeric values (optionally filtered), comparison applied to result + AVERAGE = 'average', // Average of numeric values (optionally filtered), comparison applied to result +} + export interface BaseCondition { id: string; - type: 'group' | 'comparison'; + type: 'group' | 'comparison' | 'aggregation'; } export interface ConditionGroup extends BaseCondition { @@ -35,14 +45,52 @@ export interface ConditionTargetReference { questionId: string; } -export interface ComparisonCondition extends BaseCondition { - type: 'comparison'; - target: ConditionTargetReference; // Structured reference to the target +/** + * Reusable comparison specification (operator + value pair). + * Used for filter comparisons and as a building block for condition types. + */ +export interface ComparisonSpec { operator: ComparisonOperator; - value: string | number | boolean; // The value to compare against + value: string | number | boolean; +} + +/** + * Comparison spec with a target reference. + * Used as the base for both ComparisonCondition and AggregationCondition. + */ +export interface TargetedComparisonSpec extends ComparisonSpec { + target: ConditionTargetReference; +} + +/** + * A condition that compares a single target value against an expected value. + */ +export interface ComparisonCondition + extends BaseCondition, TargetedComparisonSpec { + type: 'comparison'; +} + +/** + * Aggregation condition for evaluating across multiple values. + * Used when a target can have multiple values (e.g., answers from multiple participants). + * + * For ANY/ALL/NONE: applies operator/value to each value (filterComparison ignored) + * For COUNT/SUM/AVERAGE: + * - filterComparison (optional): filters which values to include in aggregation + * - operator/value: compares the aggregated result + */ +export interface AggregationCondition + extends BaseCondition, TargetedComparisonSpec { + type: 'aggregation'; + aggregator: AggregationOperator; + // Optional: filter values before aggregating (only used for COUNT/SUM/AVERAGE) + filterComparison?: ComparisonSpec; } -export type Condition = ConditionGroup | ComparisonCondition; +export type Condition = + | ConditionGroup + | ComparisonCondition + | AggregationCondition; /** Create a condition group with optional initial conditions */ export function createConditionGroup( @@ -72,15 +120,44 @@ export function createComparisonCondition( }; } -/** Evaluate a condition against target values */ +/** Create an aggregation condition for group contexts */ +export function createAggregationCondition( + target: ConditionTargetReference, + aggregator: AggregationOperator = AggregationOperator.ANY, + operator: ComparisonOperator = ComparisonOperator.EQUALS, + value: string | number | boolean = '', + filterComparison?: ComparisonSpec, +): AggregationCondition { + return { + id: generateId(), + type: 'aggregation', + target, + aggregator, + operator, + value, + filterComparison, + }; +} + +/** + * Evaluate a condition against target values. + * + * @param condition - The condition to evaluate + * @param targetValues - Map of target keys to values (for comparison conditions) + * @param allValues - Map of target keys to arrays of values (for aggregation conditions). + * Only needed if the condition contains aggregation conditions. + */ export function evaluateCondition( condition: Condition | undefined, targetValues: Record, + allValues?: Record, ): boolean { if (!condition) return true; if (condition.type === 'group') { - return evaluateConditionGroup(condition, targetValues); + return evaluateConditionGroup(condition, targetValues, allValues); + } else if (condition.type === 'aggregation') { + return evaluateAggregationCondition(condition, allValues ?? {}); } else { return evaluateComparisonCondition(condition, targetValues); } @@ -89,42 +166,113 @@ export function evaluateCondition( function evaluateConditionGroup( group: ConditionGroup, targetValues: Record, + allValues?: Record, ): boolean { if (group.conditions.length === 0) return true; if (group.operator === ConditionOperator.AND) { - return group.conditions.every((c) => evaluateCondition(c, targetValues)); + return group.conditions.every((c) => + evaluateCondition(c, targetValues, allValues), + ); } else { - return group.conditions.some((c) => evaluateCondition(c, targetValues)); + return group.conditions.some((c) => + evaluateCondition(c, targetValues, allValues), + ); } } -function evaluateComparisonCondition( - condition: ComparisonCondition, - targetValues: Record, +/** Apply a comparison operator to two values */ +function applyComparison( + operator: ComparisonOperator, + value: unknown, + compareValue: string | number | boolean, ): boolean { - const targetKey = getConditionTargetKey(condition.target); - const targetValue = targetValues[targetKey]; - - if (targetValue === undefined) return false; - - switch (condition.operator) { + switch (operator) { case ComparisonOperator.EQUALS: - return targetValue === condition.value; + return value === compareValue; case ComparisonOperator.NOT_EQUALS: - return targetValue !== condition.value; + return value !== compareValue; case ComparisonOperator.GREATER_THAN: - return Number(targetValue) > Number(condition.value); + return Number(value) > Number(compareValue); case ComparisonOperator.GREATER_THAN_OR_EQUAL: - return Number(targetValue) >= Number(condition.value); + return Number(value) >= Number(compareValue); case ComparisonOperator.LESS_THAN: - return Number(targetValue) < Number(condition.value); + return Number(value) < Number(compareValue); case ComparisonOperator.LESS_THAN_OR_EQUAL: - return Number(targetValue) <= Number(condition.value); + return Number(value) <= Number(compareValue); case ComparisonOperator.CONTAINS: - return String(targetValue).includes(String(condition.value)); + return String(value).includes(String(compareValue)); case ComparisonOperator.NOT_CONTAINS: - return !String(targetValue).includes(String(condition.value)); + return !String(value).includes(String(compareValue)); + default: + return false; + } +} + +function evaluateComparisonCondition( + condition: ComparisonCondition, + targetValues: Record, +): boolean { + const targetKey = getConditionTargetKey(condition.target); + const targetValue = targetValues[targetKey]; + + if (targetValue === undefined) return false; + + return applyComparison(condition.operator, targetValue, condition.value); +} + +/** Filter values by a comparison spec */ +function filterValues(values: unknown[], spec: ComparisonSpec): unknown[] { + return values.filter((v) => applyComparison(spec.operator, v, spec.value)); +} + +/** Sum numeric values */ +function sumValues(values: unknown[]): number { + return values.reduce((acc: number, v) => acc + (Number(v) || 0), 0); +} + +/** + * Evaluate an aggregation condition against an array of values. + */ +function evaluateAggregationCondition( + condition: AggregationCondition, + allValues: Record, +): boolean { + const targetKey = getConditionTargetKey(condition.target); + const values = allValues[targetKey] ?? []; + + if (values.length === 0) return false; + + // Helper to compare result against condition's operator/value + const compareResult = (result: unknown) => + applyComparison(condition.operator, result, condition.value); + + // Helper to get filtered values for COUNT/SUM/AVERAGE + const getFiltered = () => + condition.filterComparison + ? filterValues(values, condition.filterComparison) + : values; + + switch (condition.aggregator) { + case AggregationOperator.ANY: + return values.some(compareResult); + case AggregationOperator.ALL: + return values.every(compareResult); + case AggregationOperator.NONE: + return !values.some(compareResult); + case AggregationOperator.COUNT: { + const filtered = condition.filterComparison + ? filterValues(values, condition.filterComparison) + : values.filter((v) => v !== undefined && v !== null); + return compareResult(filtered.length); + } + case AggregationOperator.SUM: + return compareResult(sumValues(getFiltered())); + case AggregationOperator.AVERAGE: { + const filtered = getFiltered(); + if (filtered.length === 0) return false; + return compareResult(sumValues(filtered) / filtered.length); + } default: return false; } @@ -168,6 +316,28 @@ export function getConditionOperatorLabel(operator: ConditionOperator): string { } } +/** Get human-readable label for aggregation operator */ +export function getAggregationOperatorLabel( + operator: AggregationOperator, +): string { + switch (operator) { + case AggregationOperator.ANY: + return 'ANY value'; + case AggregationOperator.ALL: + return 'ALL values'; + case AggregationOperator.NONE: + return 'NO value'; + case AggregationOperator.COUNT: + return 'COUNT of values'; + case AggregationOperator.SUM: + return 'SUM of values'; + case AggregationOperator.AVERAGE: + return 'AVERAGE of values'; + default: + return operator; + } +} + /** Build the key string for a condition target reference */ export function getConditionTargetKey( target: ConditionTargetReference, @@ -199,6 +369,8 @@ export function extractConditionDependencies( if (condition.type === 'comparison') { dependencies.push(condition.target); + } else if (condition.type === 'aggregation') { + dependencies.push(condition.target); } else if (condition.type === 'group') { for (const subCondition of condition.conditions) { dependencies.push(...extractConditionDependencies(subCondition)); @@ -208,6 +380,20 @@ export function extractConditionDependencies( return deduplicateTargetReferences(dependencies); } +/** Check if a condition contains any aggregation conditions */ +export function hasAggregationConditions( + condition: Condition | undefined, +): boolean { + if (!condition) return false; + + if (condition.type === 'aggregation') { + return true; + } else if (condition.type === 'group') { + return condition.conditions.some((c) => hasAggregationConditions(c)); + } + return false; +} + /** Extract dependencies from multiple conditions */ export function extractMultipleConditionDependencies( conditions: (Condition | undefined)[], diff --git a/utils/src/utils/condition.utils.ts b/utils/src/utils/condition.utils.ts index 831438675..b20ce36cd 100644 --- a/utils/src/utils/condition.utils.ts +++ b/utils/src/utils/condition.utils.ts @@ -77,6 +77,53 @@ export function getConditionDependencyValues( return values; } +/** + * Build aggregated values for condition evaluation. + * + * Returns a map where each target key maps to an array of all values from all participants. + * This structure is used for aggregation conditions (ANY, ALL, COUNT, etc.). + * + * @param dependencies - The condition target references to resolve + * @param allParticipantAnswers - Map of participantId to their stage answers + * @returns Map of targetKey to array of values from all participants + */ +export function buildAllValues( + dependencies: ConditionTargetReference[], + allParticipantAnswers: Record>, +): Record { + const result: Record = {}; + + // Initialize arrays for each target + for (const targetRef of dependencies) { + const dataKey = getConditionTargetKey(targetRef); + result[dataKey] = []; + } + + // Collect values from all participants + for (const stageAnswers of Object.values(allParticipantAnswers)) { + for (const targetRef of dependencies) { + const dataKey = getConditionTargetKey(targetRef); + const stageAnswer = stageAnswers[targetRef.stageId]; + + if (!stageAnswer || !('answerMap' in stageAnswer)) { + continue; + } + + if (stageAnswer.kind === StageKind.SURVEY) { + const surveyAnswer = stageAnswer as SurveyStageParticipantAnswer; + const answer = surveyAnswer.answerMap[targetRef.questionId]; + if (answer) { + result[dataKey].push(extractAnswerValue(answer)); + } + } + // Note: SurveyPerParticipant stages are more complex for aggregation + // as they have answers per-participant already. For now, we skip these. + } + } + + return result; +} + /** * Get condition dependency values, including current stage answers that may not be persisted yet. * @@ -121,6 +168,59 @@ export function getConditionDependencyValuesWithCurrentStage( return values; } +/** + * Build aggregated values, including current stage answers that may not be persisted yet. + * + * This is used during active survey completion where we need to: + * 1. Include all participants' persisted answers (for aggregation conditions) + * 2. Merge in the current participant's unsaved answers for the current stage + * + * @param dependencies - The condition target references to resolve + * @param currentStageId - The ID of the stage currently being worked on + * @param currentStageAnswers - Map of questionId to SurveyAnswer for the current stage (unsaved) + * @param currentParticipantId - The current participant's ID (to include their unsaved answers) + * @param allParticipantAnswers - Map of participantId to their stage answers (for aggregation) + * @returns Map of targetKey to array of values + */ +export function buildAllValuesWithCurrentStage( + dependencies: ConditionTargetReference[], + currentStageId: string, + currentStageAnswers: Record, + currentParticipantId?: string, + allParticipantAnswers?: Record< + string, + Record + >, +): Record { + // Start with all participants' persisted values + const result: Record = allParticipantAnswers + ? buildAllValues(dependencies, allParticipantAnswers) + : {}; + + // Initialize arrays for any missing targets + for (const targetRef of dependencies) { + const dataKey = getConditionTargetKey(targetRef); + if (!result[dataKey]) { + result[dataKey] = []; + } + } + + // Add current participant's unsaved answers for the current stage + if (currentParticipantId) { + for (const targetRef of dependencies) { + if (targetRef.stageId === currentStageId) { + const dataKey = getConditionTargetKey(targetRef); + const answer = currentStageAnswers[targetRef.questionId]; + if (answer) { + result[dataKey].push(extractAnswerValue(answer)); + } + } + } + } + + return result; +} + /** * Evaluate a condition against stage answers. * @@ -128,25 +228,37 @@ export function getConditionDependencyValuesWithCurrentStage( * and condition evaluation in one call. * * @param condition - The condition to evaluate - * @param stageAnswers - Map of stageId to StageParticipantAnswer + * @param stageAnswers - Map of stageId to StageParticipantAnswer (for current participant) * @param targetParticipantId - For SurveyPerParticipant stages, which participant's answers to use + * @param allParticipantAnswers - Map of participantId to their stage answers (for aggregation) * @returns true if condition passes (or if condition is undefined) */ export function evaluateConditionWithStageAnswers( condition: Condition | undefined, stageAnswers: Record, targetParticipantId?: string, + allParticipantAnswers?: Record< + string, + Record + >, ): boolean { if (!condition) return true; const dependencies = extractMultipleConditionDependencies([condition]); + + // Get target values for comparison conditions const targetValues = getConditionDependencyValues( dependencies, stageAnswers, targetParticipantId, ); - return evaluateCondition(condition, targetValues); + // Build aggregated values for aggregation conditions (if multi-participant data provided) + const allValues = allParticipantAnswers + ? buildAllValues(dependencies, allParticipantAnswers) + : undefined; + + return evaluateCondition(condition, targetValues, allValues); } /** @@ -155,14 +267,19 @@ export function evaluateConditionWithStageAnswers( * Generic utility to filter any array of items that have optional conditions. * * @param items - Array of items, each potentially having a `condition` property - * @param stageAnswers - Map of stageId to StageParticipantAnswer + * @param stageAnswers - Map of stageId to StageParticipantAnswer (for current participant) * @param targetParticipantId - For SurveyPerParticipant stages, which participant's answers to use + * @param allParticipantAnswers - Map of participantId to their stage answers (for aggregation) * @returns Filtered array containing only items whose conditions pass */ export function filterByCondition( items: T[], stageAnswers: Record, targetParticipantId?: string, + allParticipantAnswers?: Record< + string, + Record + >, ): T[] { // Extract all dependencies upfront for efficiency const allConditions = items @@ -174,15 +291,22 @@ export function filterByCondition( } const allDependencies = extractMultipleConditionDependencies(allConditions); + + // Get target values for comparison conditions const targetValues = getConditionDependencyValues( allDependencies, stageAnswers, targetParticipantId, ); + // Build aggregated values for aggregation conditions (if multi-participant data provided) + const allValues = allParticipantAnswers + ? buildAllValues(allDependencies, allParticipantAnswers) + : undefined; + return items.filter((item) => { if (!item.condition) return true; - return evaluateCondition(item.condition, targetValues); + return evaluateCondition(item.condition, targetValues, allValues); }); }