Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion programs/predicate_registry/src/instructions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ pub struct SetPolicyId<'info> {
/// Updates an existing policy for a PROGRAM. Only the program's upgrade
/// authority can call this instruction.
#[derive(Accounts)]
#[instruction(client_program: Pubkey)]
pub struct UpdatePolicyId<'info> {
/// The registry account (for event emission)
#[account(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use crate::errors::PredicateRegistryError;
///
/// # Arguments
/// * `ctx` - The instruction context containing accounts
/// * `client_program` - The program address that this policy applies to
/// * `policy_id` - The new policy ID string to set
///
/// # Returns
Expand All @@ -23,7 +22,6 @@ use crate::errors::PredicateRegistryError;
/// - Policy PDA is derived from the program address, not the user
pub fn update_policy_id(
ctx: Context<UpdatePolicyId>,
client_program: Pubkey,
policy_id: String
) -> Result<()> {
require!(!policy_id.is_empty(), PredicateRegistryError::InvalidPolicyId);
Expand Down Expand Up @@ -70,10 +68,10 @@ pub fn update_policy_id(
let policy_account = &mut ctx.accounts.policy_account;
let clock = Clock::get()?;

let client_program = ctx.accounts.client_program.key();
let previous_policy_id = policy_account.policy_id.clone();
policy_account.update_policy_id(policy_id.clone(), &clock)?;

// Update registry timestamp
registry.updated_at = clock.unix_timestamp;

emit!(PolicyUpdated {
Expand Down
3 changes: 1 addition & 2 deletions programs/predicate_registry/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,9 @@ pub mod predicate_registry {
/// * `InvalidClientProgram` - If program doesn't match policy account
pub fn update_policy_id(
ctx: Context<UpdatePolicyId>,
client_program: Pubkey,
policy_id: String
) -> Result<()> {
instructions::update_policy_id(ctx, client_program, policy_id)
instructions::update_policy_id(ctx, policy_id)
}

/// Validate an attestation for a transaction
Expand Down
2 changes: 1 addition & 1 deletion scripts/set-customer-policy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async function setCustomerPolicy(customerProgramId: string, policyId: string) {

console.log("\n📝 Updating policy...");
const tx = await program.methods
.updatePolicyId(customerProgram, policyId)
.updatePolicyId(policyId)
.accounts({
registry: registryPda,
policyAccount: policyPda,
Expand Down
48 changes: 46 additions & 2 deletions tests/functional/policy-management.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,50 @@ describe("Policy Management", () => {
expect(policyAfter.setAt.toNumber()).to.equal(setAtBefore); // Should not change
});

it("Should emit PolicyUpdated event with correct client_program", async () => {
const [policyPda] = findPolicyPDA(
counterProgramId,
context.program.programId
);
const policyBefore = await context.program.account.policyAccount.fetch(
policyPda
);
const previousPolicyId = policyBefore.policyId;
let eventReceived = false;

const listener = context.program.addEventListener(
"policyUpdated",
(event: any) => {
expect(event.registry.toString()).to.equal(
context.registry.registryPda.toString()
);
expect(event.clientProgram.toString()).to.equal(
counterProgramId.toString()
);
expect(event.authority.toString()).to.equal(
context.authority.keypair.publicKey.toString()
);
expect(event.previousPolicyId).to.equal(previousPolicyId);
expect(event.newPolicyId).to.equal(updatedPolicyId);
expect(event.timestamp.toNumber()).to.be.greaterThan(0);
eventReceived = true;
}
);

await updatePolicyId(
context.program,
counterProgramId,
context.authority.keypair,
updatedPolicyId,
context.registry.registryPda
);

await new Promise((resolve) => setTimeout(resolve, 100));
expect(eventReceived).to.be.true;

await context.program.removeEventListener(listener);
});

it("Should NOT increment registry policy count when updating existing policy", async () => {
const [policyPda] = findPolicyPDA(
counterProgramId,
Expand Down Expand Up @@ -252,7 +296,7 @@ describe("Policy Management", () => {

try {
await context.program.methods
.updatePolicyId(counterProgramId, updatedPolicyId)
.updatePolicyId(updatedPolicyId)
.accounts({
registry: context.registry.registryPda,
policyAccount: policyPda,
Expand Down Expand Up @@ -285,7 +329,7 @@ describe("Policy Management", () => {

try {
await context.program.methods
.updatePolicyId(programWithoutPolicy, updatedPolicyId)
.updatePolicyId(updatedPolicyId)
.accounts({
registry: context.registry.registryPda,
policyAccount: policyPda,
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/test-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ export async function updatePolicyId(
);

return await program.methods
.updatePolicyId(clientProgram, policyId)
.updatePolicyId(policyId)
.accounts({
registry: registryPda,
policyAccount: policyPda,
Expand Down
10 changes: 8 additions & 2 deletions tests/security/uuid-replay-via-cleanup.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,22 @@ describe("UUID Replay Prevention via Cleanup", () => {
* Helper function to create message hash for signing
*/
function createMessageHash(statement: any): Buffer {
// Hash variable-length fields separately to prevent collisions
const encodedSigAndArgsHash = crypto.createHash("sha256").update(statement.encodedSigAndArgs).digest();
const policyIdHash = crypto.createHash("sha256").update(Buffer.from(statement.policyId, "utf8")).digest();

// Concatenate fixed-length fields with hashed variable-length fields
const data = Buffer.concat([
Buffer.from(statement.uuid),
statement.msgSender.toBuffer(),
statement.target.toBuffer(),
Buffer.from(statement.msgValue.toBuffer("le", 8)),
statement.encodedSigAndArgs,
Buffer.from(statement.policyId, "utf8"),
encodedSigAndArgsHash,
policyIdHash,
Buffer.from(statement.expiration.toBuffer("le", 8)),
]);

// Hash the data using SHA-256 (Solana's hash function)
return crypto.createHash("sha256").update(data).digest();
}

Expand Down