From 2fee72ea7a73e455adb665205f507912ec1826df Mon Sep 17 00:00:00 2001 From: shauray8 Date: Wed, 10 Apr 2024 23:43:41 +0530 Subject: [PATCH 1/3] move to a more forgiving distance --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 6cc4d26f29b4..64e426aca42b 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -28,6 +28,10 @@ def cosine_distance(image_embeds, text_embeds): normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) +def jaccard_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig From bef9f56179951323c6da82ef5e177487074d1be2 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 11 Apr 2024 00:05:24 +0530 Subject: [PATCH 2/3] something something jaccard distance --- .../pipelines/stable_diffusion/safety_checker.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 64e426aca42b..6d44b7b57d94 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -28,11 +28,13 @@ def cosine_distance(image_embeds, text_embeds): normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) -def jaccard_distance(image_embeds, text_embeds): - normalized_image_embeds = nn.functional.normalize(image_embeds) - normalized_text_embeds = nn.functional.normalize(text_embeds) - +def jaccard_distance(image_embeds, text_embeds, eps=1e-8): + scaler = torch.bmm(image_embeds.unsqueeze(1),text_embeds.unsqueeze(2)).squeeze(2)) + image_square = image_embeds.pow(2).sum(dim=-1, keepdim=True) + text_square = text_embeds.pow(2).sum(dim=-1, keepdim=True) + return scaler / (image_square + text_square.transpose(0,1) - scaler + eps) + class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig From 72aae526c8f009ddcb50a498f66ed6846a99da3a Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 11 Apr 2024 01:50:51 +0530 Subject: [PATCH 3/3] why safety models are stupid --- .../stable_diffusion/safety_checker.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 6d44b7b57d94..b57fd1ef4ae1 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -28,13 +28,16 @@ def cosine_distance(image_embeds, text_embeds): normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) -def jaccard_distance(image_embeds, text_embeds, eps=1e-8): - scaler = torch.bmm(image_embeds.unsqueeze(1),text_embeds.unsqueeze(2)).squeeze(2)) +## Seems to be working better for now, still not the best of safety models +def jaccard_distance(image_embeds, text_embeds, eps=-1): + scaler = torch.matmul(image_embeds, text_embeds.t()) image_square = image_embeds.pow(2).sum(dim=-1, keepdim=True) text_square = text_embeds.pow(2).sum(dim=-1, keepdim=True) + print((scaler / (image_square + text_square.transpose(0,1) - scaler + eps))*2) + print(f'{cosine_distance(image_embeds,text_embeds)=}') + return (scaler / (image_square + text_square.transpose(0,1) - scaler + eps))*2 + - return scaler / (image_square + text_square.transpose(0,1) - scaler + eps) - class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig @@ -58,8 +61,8 @@ def forward(self, clip_input, images): image_embeds = self.visual_projection(pooled_output) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + special_cos_dist = jaccard_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = jaccard_distance(image_embeds, self.concept_embeds).cpu().float().numpy() result = [] batch_size = image_embeds.shape[0] @@ -109,8 +112,8 @@ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor) pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) - cos_dist = cosine_distance(image_embeds, self.concept_embeds) + special_cos_dist = jaccard_distance(image_embeds, self.special_care_embeds) + cos_dist = jaccard_distance(image_embeds, self.concept_embeds) # increase this value to create a stronger `nsfw` filter # at the cost of increasing the possibility of filtering benign images @@ -129,3 +132,4 @@ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor) images[has_nsfw_concepts] = 0.0 # black image return images, has_nsfw_concepts +