Skip to content

Why didn't I reproduce Sam3 adapter's more accurate delineation of polyp contours than Sam3? #133

@jackcjp

Description

@jackcjp

为什么我并没有复现sam3-adapter比sam3更好的轮廓分割(sam3不完整,边界模糊)?

1.我使用的是kvasir数据集,使用的是main分支SAM3-Adapter-Pytorch(2).zip代码来尝试复现的

2.我对kvasir数据集拆分为train 800张,val100张,test100张,你知乎上figture6中的4张样本刚好不在test中(以下这4张称为关键样本)
cju15czxqp3lv0835jvhgzurz.jpg -> val
cju7ebe962hr409872ovibahw.jpg -> train
cju16whaj0e7n0855q7b6cjkm.jpg -> train
ck2bxw18mmz1k0725litqq2mc.jpg -> train

3.我使用
eval_type: kvasir
sam_checkpoint: /mnt/tempdata/chenjp/sam-adapter/SAM-Adapter-PyTorch/pretrained/sam3.pt
config: /mnt/tempdata/chenjp/sam-adapter/SAM3-Adapter-Pytorch/configs/cod-sam-vit-l.yaml(只改动了train_dataset,val_dataset,test_dataset的路径)

train_dataset,val_dataset,test_dataset分别对应上面的train/val/test,运行train.sh得到的pth文件
/mnt/tempdata/chenjp/sam-adapter/SAM3-Adapter-Pytorch/save/_cod-sam-vit-l/model_epoch_best.pth

4.我使用同样的config cod-sam-vit-l.yaml,使用步骤3中的model_epoch_best.pth,运行test.sh得到的dice和iou为

# ################ Results ################  (sam3-adapter)                                                                                                                                                                                                                                   
# 0.9120600180868292: 0.9121
# 0.8430836232751606: 0.8431
# 0.0: 0.0000
# 0.0: 0.0000
# #########################################
  1. 我使用sam3官方的示例代码用test组的样本数据和GT mask做bbox提示,给sam3做输入,拿到预测值做了iou
# ========================================
# Final mIoU (sam3 Style): 0.8466
# Results saved to: /mnt/tempdata/chenjp/sam-adapter/results_sam3_style_kvasir
# ========================================
  1. 我看了知乎上你发的文章说sam3-adapter的效果比sam3分割结果更完整,边界清晰,我试了感觉效果不明显,
    就如第2点中所列的几个关键样本,我单独用第3步的model_epoch_best.pth,没有感觉到sam3-adapter比sam3更好,
    使用上面步骤2的4个关键样本来跑test.sh得到的结果是 (sam3-adapter)
# ################ Results ################                                                                                                                                                                                                                                     
# 0.8602661245424534: 0.8603
# 0.7395057827234268: 0.7395
# 0.0: 0.0000
# 0.0: 0.0000
# #########################################

sam3脚本得到的结果是

# ========================================
# Final mIoU (sam3 Style): 0.7303
# Results saved to: /mnt/tempdata/chenjp/sam-adapter/results_sam3_style_kvasir_re_pred
# ========================================

问题:
1.我应该怎么做才能复现这种效果?

2.还想知道你用的的sam3的脚本在代码中吗?权重文件是pt格式的吗?
3.我需要把那4个关键样本从train/val中剔除重新生成一个model_epoch_best.pth文件来跑这几个关键样本吗?还是说我使用你readme中列的其他数据集来验证?

4.我看到咱sam3-adapter的论文文章中有对比sam3和sam3-adapter,但是table1,2,3中都有sam,sam2,但没有列sam3的指标,是为什么?是sam3的指标更好,无法印证sam3-adapter的价值,不便于展示?还是疏忽?(我猜是adapter可能在sam、sam2时确实有较大提升,但在sam3时几乎无提升,或不太明显,故而不列sam3来,只凸显sam3-adapter是sota,使论文能顺利发表?)

sam3的结果与sam3-adapter的结果比较(上面是sam3下面是sam3-adapter)
Image
Image

Image Image Image Image Image Image

使用sam3分割的代码,test_sam3_style_kvasir.py

import os
import glob
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

# ================= 配置路径 =================
sys.path.insert(0, "/mnt/tempdata/chenjp/sam3")
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# 数据路径
IMAGE_ROOT = '/mnt/tempdata/chenjp/sam-adapter/re-pred/images'
MASK_ROOT = '/mnt/tempdata/chenjp/sam-adapter/re-pred/masks'
MODEL_DIR = "/mnt/tempdata/chenjp/sam-adapter/sam3"

# 输出结果
OUTPUT_DIR = "/mnt/tempdata/chenjp/sam-adapter/results_sam3_style_kvasir_re_pred"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ================= 核心:官方 sam3 的坐标转换逻辑 =================

def box_xyxy_to_cxcywh(x1, y1, x2, y2):
    """
    将 [x1, y1, x2, y2] 转换为 [cx, cy, w, h]
    """
    w = x2 - x1
    h = y2 - y1
    cx = x1 + w / 2
    cy = y1 + h / 2
    return [cx, cy, w, h]

def normalize_bbox(box_cxcywh, img_w, img_h):
    """
    归一化 [cx, cy, w, h] 到 [0, 1]
    """
    cx, cy, w, h = box_cxcywh
    norm_box = [
        cx / img_w,
        cy / img_h,
        w / img_w,
        h / img_h
    ]
    return norm_box

# ================= 工具函数 =================

def get_bbox_from_mask(mask):
    """从 GT Mask 获取 [x1, y1, x2, y2] (绝对像素)"""
    y_indices, x_indices = np.where(mask > 0)
    if len(y_indices) == 0: return None
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    return [x_min, y_min, x_max, y_max]

def calculate_iou(pred_mask, gt_mask):
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    return intersection / union if union > 0 else 0.0

def save_visualization(image_np, box_xyxy, pred_mask, gt_mask, iou, score, filename):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # 1. 原图 + Box
    axes[0].imshow(image_np)
    if box_xyxy is not None:
        x1, y1, x2, y2 = box_xyxy
        rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='red', facecolor='none')
        axes[0].add_patch(rect)
    axes[0].set_title(f"Input: Image + Box\n{box_xyxy}")
    axes[0].axis('off')
    
    # 2. 预测 Mask (黑底白字)
    axes[1].imshow(pred_mask, cmap='gray', vmin=0, vmax=1)
    axes[1].set_title(f"Pred Mask (Score: {score:.3f})\nIoU: {iou:.4f}")
    axes[1].axis('off')
    
    # 3. GT Mask
    axes[2].imshow(gt_mask, cmap='gray', vmin=0, vmax=1)
    axes[2].set_title("GT Mask")
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f"{filename}_viz.png"))
    plt.close(fig)

# ================= 主流程 =================

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # 1. 加载模型
    print("Loading SAM 3 model...")
    model = build_sam3_image_model(
        checkpoint_path=os.path.join(MODEL_DIR, "sam3.pt"), 
        device=device, eval_mode=True, load_from_HF=False
    )
    processor = Sam3Processor(model, device=device)

    # 2. 准备数据
    img_files = sorted(glob.glob(os.path.join(IMAGE_ROOT, "*.jpg")) + glob.glob(os.path.join(IMAGE_ROOT, "*.png")))
    print(f"Testing {len(img_files)} images using sam3 Logic (Norm CXCYWH)...")

    all_ious = []

    for i in tqdm(range(len(img_files))):
        img_path = img_files[i]
        filename = os.path.basename(img_path)
        
        # 匹配 Mask
        basename = os.path.splitext(filename)[0]
        mask_path = None
        for p in [os.path.join(MASK_ROOT, filename), os.path.join(MASK_ROOT, basename + ".png")]:
            if os.path.exists(p): mask_path = p; break
        if not mask_path: continue

        try:
            # 读取数据
            image = Image.open(img_path).convert("RGB")
            w, h = image.size
            image_np = np.array(image)
            
            gt_img = Image.open(mask_path).convert('L').resize((w, h), Image.NEAREST)
            gt_mask = np.array(gt_img) > 128
            
            # 获取绝对像素 Box [x1, y1, x2, y2]
            bbox_xyxy = get_bbox_from_mask(gt_mask)
            if bbox_xyxy is None: continue
            
            # -------------------------------------------------
            # 关键修改:严格遵循 sam3 逻辑
            # -------------------------------------------------
            
            # 1. 像素 XYXY -> 像素 CXCYWH
            box_cxcywh = box_xyxy_to_cxcywh(*bbox_xyxy)
            
            # 2. 像素 CXCYWH -> 归一化 CXCYWH (0.0-1.0)
            norm_box_cxcywh = normalize_bbox(box_cxcywh, w, h)
            
            # 3. 设置图像
            inference_state = processor.set_image(image)
            
            # 4. 重置 Prompts (sam3 里强调了这步)
            processor.reset_all_prompts(inference_state)
            
            # 5. 添加几何提示
            # 注意:sam3 里 box 参数传的是 flatten list,不需要嵌套 batch
            output = processor.add_geometric_prompt(
                state=inference_state, 
                box=norm_box_cxcywh, 
                label=True
            )
            
            # -------------------------------------------------
            
            masks = output["masks"]
            scores = output["scores"]
            
            # 提取最高分 Mask
            if isinstance(scores, list): 
                best_idx = np.argmax(scores)
                score_val = scores[best_idx]
            else: 
                best_idx = torch.argmax(scores).item()
                score_val = scores[best_idx].item()
            
            if isinstance(masks, list): pred_tensor = masks[best_idx]
            else: pred_tensor = masks[best_idx]
            
            if pred_tensor.ndim == 3: pred_tensor = pred_tensor.squeeze(0)
            
            # 后处理 Resize 回原图 (通常 SAM 输出的是内部尺寸)
            if pred_tensor.shape[0] != h or pred_tensor.shape[1] != w:
                pred_resized = torch.nn.functional.interpolate(
                    pred_tensor.unsqueeze(0).unsqueeze(0).float(),
                    size=(h, w), 
                    mode='bilinear', 
                    align_corners=False
                ).squeeze()
            else:
                pred_resized = pred_tensor
                
            pred_prob = pred_resized.cpu().numpy()
            pred_binary = (pred_prob > 0.5)
            
            # 计算 IoU
            iou = calculate_iou(pred_binary, gt_mask)
            all_ious.append(iou)
            
            # 可视化 (Box 画图时还是用 xyxy,方便观察)
            save_visualization(image_np, bbox_xyxy, pred_binary, gt_mask, iou, score_val, filename)

        except Exception as e:
            print(f"Error {filename}: {e}")
            # import traceback
            # traceback.print_exc()
            continue

    print("\n" + "="*40)
    if all_ious:
        print(f"Final mIoU (sam3 Style): {np.mean(all_ious):.4f}")
        print(f"Results saved to: {OUTPUT_DIR}")
    else:
        print("No results.")
    print("="*40)

if __name__ == "__main__":
    main()



@tianrun-chen

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions