
Tracin group attribute #245

Open
Soulknight-T wants to merge 17 commits into TRAIS-Lab:main from Soulknight-T:tracin_group_attribute

Conversation

@Soulknight-T

Description

Add TracIn test_dataloader_group to support memory-efficient computation.
Add an example and a test for test_dataloader_group.

1. Motivation and Context

Add an alternative way to compute the score matrix that uses less memory.

2. Summary of the change

  • Add TestDataloaderGroup class to TracIn.py
  • Add test to test_tracin.py
  • Add usage example
  • Update GitHub example workflow

3. What tests have been added/updated for the change?

  • Unit test: added, since this change implements a new function.

@Soulknight-T Soulknight-T marked this pull request as draft February 10, 2026 01:01
@Soulknight-T Soulknight-T marked this pull request as ready for review February 10, 2026 01:08
}


class DataloaderGroup:
Collaborator

Inherit from torch.utils.data.DataLoader.
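The reviewer's suggestion could look like the minimal sketch below. The names `DataloaderGroup` and `original_test_dataloader` come from this PR; the constructor arguments forwarded to `DataLoader` are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class DataloaderGroup(DataLoader):
    """Mark a test dataloader as a single attribution group.

    Subclassing DataLoader (as suggested) lets attributors use this
    wrapper anywhere a plain DataLoader is expected, so ad-hoc members
    such as batch_size and sampler need not be re-declared by hand.
    """

    def __init__(self, original_test_dataloader: DataLoader) -> None:
        self.original_test_dataloader = original_test_dataloader
        # Reuse the wrapped loader's dataset; DataLoader.__init__ then
        # provides batch_size, sampler, etc. automatically.
        super().__init__(
            original_test_dataloader.dataset,
            batch_size=original_test_dataloader.batch_size,
        )


loader = DataLoader(TensorDataset(torch.zeros(8, 2), torch.zeros(8)), batch_size=4)
group = DataloaderGroup(loader)
print(isinstance(group, DataLoader), len(group))  # True 2
```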

"""Initialize the DataloaderGroup.

Args:
original_test_dataloader (DataLoader): The underlying PyTorch dataloader.
Collaborator

  • for individual test data samples

original_test_dataloader (DataLoader): The underlying PyTorch dataloader.
"""
self.original_test_dataloader = original_test_dataloader
self.batch_size = 1
Collaborator

If the batch_size is not used, then we can delete this member.

"""
self.original_test_dataloader = original_test_dataloader
self.batch_size = 1
self.sampler = [0]
Collaborator

After inheriting from the DataLoader class, we no longer need this member.

else:
    temp = sub_batch.to(self.device)

sub_grad = torch.nan_to_num(self.grad_target_func(parameters, temp))
Collaborator

We may want a slightly different user-defined loss/target function (shown only in the example) and avoid the changes in attributor.attribute().
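As a sketch of what such a user-defined group target function might look like: it takes the parameters and a whole dataloader and returns one scalar. The single-linear-layer model and squared-error loss here are assumptions for illustration, not the PR's actual example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def group_target_func(params: dict, loader: DataLoader) -> torch.Tensor:
    """Scalar target for a whole test dataloader: summed squared error."""
    weight = params["weight"]  # assumed: a single nn.Linear(2, 1, bias=False)
    total = torch.tensor(0.0)
    for x, y in loader:
        pred = (x @ weight.T).squeeze(-1)
        total = total + ((pred - y) ** 2).sum()
    return total


# With weight [[1, 1]], predictions match the labels exactly, so the loss is 0.
params = {"weight": torch.ones(1, 2)}
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([3.0, 7.0])
loader = DataLoader(TensorDataset(x, y), batch_size=2)
print(group_target_func(params, loader))  # tensor(0.)
```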

def get_param(self, *args, **kwargs):
    return dict(model.named_parameters()), None

def get_grad_loss_func(self, *args, **kwargs):
    return func

def get_grad_target_func(self, *args, **kwargs):
    return func_group
Collaborator

For the examples, we need to create a task using the AttributionTask API.

Collaborator

But with a different target function (compared to other examples).

@TheaperDeng TheaperDeng changed the title Tracin group attribute [WIP] Tracin group attribute Feb 16, 2026
Collaborator

@TheaperDeng left a comment

Please change accordingly

if hasattr(test_dataloader, "original_test_dataloader"):
    _check_shuffle(test_dataloader.original_test_dataloader)
else:
    _check_shuffle(test_dataloader)
Collaborator

Remove this change.

input_dim, n_train, n_test = 2, 10, 5

model = nn.Linear(input_dim, 1, bias=False)
model.weight.data.fill_(1.0)
Collaborator

Use mnist_mlp instead.

dattri/task.py Outdated
group_target_func (Callable): Optional. When attributing to a group (e.g. a
DataLoader passed via DataloaderGroup), this scalar function is used
instead of the per-sample target. Signature (params_dict, loader) -> scalar.
The gradient of this w.r.t. params is the test-side gradient for the group.
Collaborator

Let's make it a bool option, default to False. And please change the docstring of target_func, i.e., "when group_target_func=True, it should take the parameters ..."

dattri/task.py Outdated
g = grad(flat_group_target)(parameters)
return g.unsqueeze(0)

return base_grad_target(parameters, data)
Collaborator

Please still return self.grad_target_func when group_target_func=False, and only return a separately wrapped function when group_target_func=True.
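The requested control flow might be sketched as follows. The flat-tensor parameters and the direct `torch.func.grad` usage are simplifying assumptions; dattri's actual parameter handling may differ.

```python
import torch
from torch.func import grad


def make_grad_target(target_func, group_target_func: bool = False):
    """Return a gradient-of-target function.

    group_target_func=False: per-sample path, unchanged from before.
    group_target_func=True: wrap the scalar group target and keep a
    leading "group" axis on the result.
    """
    if not group_target_func:
        return grad(target_func)  # differentiates w.r.t. the first argument

    def grouped(parameters, loader):
        g = grad(lambda p: target_func(p, loader))(parameters)
        return g.unsqueeze(0)

    return grouped


target = lambda p, x: (p * x).sum()
f = make_grad_target(target, group_target_func=True)
print(f(torch.ones(3), torch.tensor([1.0, 2.0, 3.0])))  # tensor([[1., 2., 3.]])
```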

normalized_grad=False,
device=args.device,
)
attributor.projector_kwargs = None
Collaborator

Why do we need this line?

model=model,
checkpoints=model.state_dict(),
target_func=f,
group_target_func=group_target_func,
Collaborator

...
target_func=group_target_func
group_target_func=True
...

print(f"Calculated Scores (first 10):\n{scores.flatten()[:10]}")
print(f"Calculated Scores Temp sum over test (first 10):\n{scores_temp.sum(dim=1)[:10]}")
diff = (scores.flatten() - scores_temp.sum(dim=1)).abs()
print(f"Max |group - sum(per-test)|: {diff.max().item():.6f}")
Collaborator

What's the output of this script? Could you paste it here?

Author

Test Dataloader Group (AttributionTask + group_target_func=True) — MNIST + MLP.
Score Shape: torch.Size([10000, 1])
Calculated Scores (first 10):
tensor([-2.2991e+00, 1.0665e-04, -1.4294e-01, 9.3012e-05, 1.6025e-01,
-2.3018e-02, 8.6976e-06, 1.5331e-07, -1.1255e-02, 8.6521e-07])
Calculated Scores Temp sum over test (first 10):
tensor([-2.2992e+00, 1.0665e-04, -1.4294e-01, 9.3012e-05, 1.6025e-01,
-2.3017e-02, 8.6975e-06, 1.5331e-07, -1.1255e-02, 8.6522e-07])
Max |group - sum(per-test)|: 0.005127
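The near-agreement above is expected: gradients are linear, so the gradient of a summed target equals the sum of the per-sample gradients, and the group score should match the per-test scores summed over the test axis up to floating-point error. A minimal check of that identity, using toy tensors rather than the MNIST setup:

```python
import torch
from torch.func import grad

target = lambda p, x: (p * x).sum()  # toy per-sample target
p = torch.tensor([1.0, 2.0])
xs = [torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0]), torch.tensor([1.0, 1.0])]

# Gradient of the group target (sum over all test points at once)...
group_grad = grad(lambda p: sum(target(p, x) for x in xs))(p)
# ...equals the per-sample gradients summed afterwards.
per_sample_sum = sum(grad(target)(p, x) for x in xs)
print(torch.allclose(group_grad, per_sample_sum))  # True
```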

@TheaperDeng TheaperDeng changed the title [WIP] Tracin group attribute Tracin group attribute Mar 3, 2026
@TheaperDeng
Collaborator

Please also fix the lint error
