diff --git a/semilearn/algorithms/semireward/__init__.py b/semilearn/algorithms/semireward/__init__.py
new file mode 100644
index 000000000..fedbee4c3
--- /dev/null
+++ b/semilearn/algorithms/semireward/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .semireward import (add_gaussian_noise, cosine_similarity_n,
+                         Generator, Rewarder, EMARewarder, label_dim)
\ No newline at end of file
diff --git a/semilearn/algorithms/semireward/semireward.py b/semilearn/algorithms/semireward/semireward.py
new file mode 100644
index 000000000..a49bfbbdc
--- /dev/null
+++ b/semilearn/algorithms/semireward/semireward.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Generator(nn.Module):
+    """Fake Label Generator in SemiReward"""
+
+    def __init__(self, feature_dim=384):
+        super(Generator, self).__init__()
+        self.fc_layers = nn.Sequential(
+            nn.Linear(feature_dim, 256),
+            nn.ReLU(),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 64),
+            nn.ReLU(),
+            nn.Linear(64, 1)
+        )
+
+    def forward(self, x):
+        x = self.fc_layers(x)
+        x = F.relu(x)
+        return x
+
+
+class Rewarder(nn.Module):
+    """Pseudo Label Reward in SemiReward"""
+
+    def __init__(self, label_dim, label_embedding_dim, feature_dim=384):
+        super(Rewarder, self).__init__()
+
+        # Feature Processing Part
+        self.feature_fc = nn.Linear(feature_dim, 128)
+        self.feature_norm = nn.LayerNorm(128)
+
+        # Label Embedding Part
+        self.label_embedding = nn.Embedding(label_dim, label_embedding_dim)
+        self.label_norm = nn.LayerNorm(label_embedding_dim)
+
+        # Cross-Attention Mechanism
+        self.cross_attention_fc = nn.Linear(128, 1)
+
+        # MLP (Multi-Layer Perceptron)
+        self.mlp_fc1 = nn.Linear(128, 256)
+        self.mlp_fc2 = nn.Linear(256, 128)
+
+        # Feed-Forward Network (FFN)
+        self.ffn_fc1 = nn.Linear(128, 64)
+        self.ffn_fc2 = nn.Linear(64, 1)
+
+    def forward(self, features, label_indices):
+        # Process Features
+        features = self.feature_fc(features)
+        features = self.feature_norm(features)
+        # Process Labels
+        label_embed = self.label_embedding(label_indices)
+        label_embed = self.label_norm(label_embed)
+        # Cross-Attention Mechanism
+        cross_attention_input = torch.cat((features, label_embed), dim=0)
+        cross_attention_weights = torch.softmax(self.cross_attention_fc(cross_attention_input), dim=0)
+        cross_attention_output = (cross_attention_weights * cross_attention_input).sum(dim=0)
+
+        # MLP Part
+        mlp_input = torch.add(cross_attention_output.unsqueeze(0).expand(label_embed.size(0), -1), label_embed)
+        mlp_output = F.relu(self.mlp_fc1(mlp_input))
+        mlp_output = self.mlp_fc2(mlp_output)
+
+        # FFN Part
+        ffn_output = F.relu(self.ffn_fc1(mlp_output))
+        reward = torch.sigmoid(self.ffn_fc2(ffn_output))
+        return reward
+
+
+class EMARewarder(Rewarder):
+    """EMA version of Reward in SemiReward"""
+
+    def __init__(self, label_dim, label_embedding_dim, feature_dim=384, ema_decay=0.9):
+        super(EMARewarder, self).__init__(
+            label_dim=label_dim, label_embedding_dim=label_embedding_dim, feature_dim=feature_dim)
+
+        # EMA decay rate
+        self.ema_decay = ema_decay
+
+        # Initialize EMA parameters
+        self.ema_params = {}
+        self.initialize_ema()
+
+    def initialize_ema(self):
+        for name, param in self.named_parameters():
+            if param.requires_grad:
+                self.ema_params[name] = nn.Parameter(param.data.clone())
+
+    def update_ema(self):
+        for name, param in self.named_parameters():
+            if param.requires_grad:
+                ema_param = self.ema_params[name]
+                if ema_param.device != param.device:
+                    ema_param.data = param.data.clone().to(ema_param.device)
+                else:
+                    ema_param.data.mul_(self.ema_decay).add_((1 - self.ema_decay) * param.data)
+
+    def forward(self, features, label_indices):
+        # Process Features
+        features = self.feature_fc(features)
+        features = self.feature_norm(features)
+        # Process Labels
+        label_embed = self.label_embedding(label_indices)
+        label_embed = self.label_norm(label_embed)
+        # Cross-Attention Mechanism
+        cross_attention_input = torch.cat((features, label_embed), dim=0)
+        cross_attention_weights = torch.softmax(self.cross_attention_fc(cross_attention_input), dim=0)
+        cross_attention_output = (cross_attention_weights * cross_attention_input).sum(dim=0)
+
+        # MLP Part
+        mlp_input = torch.add(cross_attention_output.unsqueeze(0).expand(label_embed.size(0), -1), label_embed)
+        mlp_output = F.relu(self.mlp_fc1(mlp_input))
+        mlp_output = self.mlp_fc2(mlp_output)
+
+        # FFN Part
+        ffn_output = F.relu(self.ffn_fc1(mlp_output))
+        reward = torch.sigmoid(self.ffn_fc2(ffn_output))
+
+        # Update EMA parameters
+        self.update_ema()
+
+        return reward
+
+
+def cosine_similarity_n(x, y):
+    # Calculate cosine similarity along the last dimension (dim=-1)
+    cosine_similarity = torch.cosine_similarity(x, y, dim=-1, eps=1e-8)
+
+    # Map from [-1, 1] to [0, 1] and reshape the result to [x.size(0), 1]
+    normalized_similarity = (cosine_similarity + 1) / 2
+    normalized_similarity = normalized_similarity.view(x.size(0), 1)
+
+    return normalized_similarity
+
+
+def add_gaussian_noise(tensor, mean=0, std=1):
+    noise = torch.randn_like(tensor) * std + mean
+    noisy_tensor = tensor + noise
+    return noisy_tensor
+
+
+def label_dim(x, default_dim=100):
+    return int(max(default_dim, x))
\ No newline at end of file
diff --git a/semilearn/algorithms/srflexmatch/__init__.py b/semilearn/algorithms/srflexmatch/__init__.py
new file mode 100644
index 000000000..865507c7a
--- /dev/null
+++ b/semilearn/algorithms/srflexmatch/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .srflexmatch import SRFlexMatch
+from .utils import FlexMatchThresholdingHook
\ No newline at end of file
diff --git a/semilearn/algorithms/srflexmatch/srflexmatch.py b/semilearn/algorithms/srflexmatch/srflexmatch.py
new file mode 100644
index 000000000..b8d44b484
--- /dev/null
+++ b/semilearn/algorithms/srflexmatch/srflexmatch.py
@@ -0,0 +1,246 @@
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+import torch.nn.functional as F
+from .utils import FlexMatchThresholdingHook
+from semilearn.core import AlgorithmBase
+from semilearn.core.utils import ALGORITHMS, send_model_cuda
+from semilearn.algorithms.hooks import PseudoLabelingHook
+from semilearn.algorithms.utils import SSL_Argument, str2bool
+from semilearn.algorithms.semireward import Rewarder, Generator, EMARewarder, cosine_similarity_n, label_dim
+
+
+@ALGORITHMS.register('srflexmatch')
+class SRFlexMatch(AlgorithmBase):
+    """
+        FlexMatch algorithm (https://arxiv.org/abs/2110.08263).
+        SemiReward algorithm (https://arxiv.org/abs/2310.03013).
+
+        Args:
+            - args (`argparse`):
+                algorithm arguments
+            - net_builder (`callable`):
+                network loading function
+            - tb_log (`TBLog`):
+                tensorboard logger
+            - logger (`logging.Logger`):
+                logger to use
+            - T (`float`):
+                Temperature for pseudo-label sharpening
+            - p_cutoff (`float`):
+                Confidence threshold for generating pseudo-labels
+            - hard_label (`bool`, *optional*, default to `True`):
+                If True, targets have [Batch size] shape with int values.
+                If False, targets are soft pseudo-label vectors with [Batch size, num_classes] shape.
+            - ulb_dest_len (`int`):
+                Length of unlabeled data
+            - thresh_warmup (`bool`, *optional*, default to `True`):
+                If True, warm up the confidence threshold so that, at the beginning of training, the estimated
+                learning effects rise gradually from 0 until unused unlabeled data no longer predominates
+
+    """
+    def __init__(self, args, net_builder, tb_log=None, logger=None):
+        super().__init__(args, net_builder, tb_log, logger)
+        # flexmatch-specific arguments
+        self.init(T=args.T, p_cutoff=args.p_cutoff, hard_label=args.hard_label, thresh_warmup=args.thresh_warmup)
+        self.N_k = args.N_k
+        self.rewarder = send_model_cuda(args, Rewarder(label_dim(self.num_classes), 128, args.feature_dim)) if args.sr_ema == 0 \
+            else send_model_cuda(args, EMARewarder(label_dim(self.num_classes), 128, feature_dim=args.feature_dim, ema_decay=args.sr_ema_m), clip_batch=False)
+        self.generator = send_model_cuda(args, Generator(args.feature_dim))
+        self.start_timing = args.start_timing
+
+        self.rewarder_optimizer = torch.optim.Adam(self.rewarder.parameters(), lr=args.sr_lr)
+        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=args.sr_lr)
+
+        self.criterion = torch.nn.MSELoss()
+
+        self.max_reward = -float('inf')
+
+    def init(self, T, p_cutoff, hard_label=True, thresh_warmup=True):
+        self.T = T
+        self.p_cutoff = p_cutoff
+        self.use_hard_label = hard_label
+        self.thresh_warmup = thresh_warmup
+
+    def set_hooks(self):
+        self.register_hook(PseudoLabelingHook(), "PseudoLabelingHook")
+        self.register_hook(FlexMatchThresholdingHook(
+            ulb_dest_len=self.args.ulb_dest_len, num_classes=self.num_classes, thresh_warmup=self.args.thresh_warmup), "MaskingHook")
+        super().set_hooks()
+
+    def data_generator(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s, rewarder, gpu):
+        rewarder = rewarder.eval()
+        for _ in range(self.sr_decay()):
+            num_lb = y_lb.shape[0]
+            with self.amp_cm():
+                if self.use_cat:
+                    inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
+                    outputs = self.model(inputs)
+                    logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
+                    feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
+                else:
+                    outs_x_ulb_s = self.model(x_ulb_s)
+                    logits_x_ulb_s = outs_x_ulb_s['logits']
+                    feats_x_ulb_s = outs_x_ulb_s['feat']
+                    with torch.no_grad():
+                        outs_x_ulb_w = self.model(x_ulb_w)
+                        logits_x_ulb_w = outs_x_ulb_w['logits']
+                        feats_x_ulb_w = outs_x_ulb_w['feat']
+
+                probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())
+                pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
+                                              logits=probs_x_ulb_w,
+                                              use_hard_label=self.use_hard_label,
+                                              T=self.T,
+                                              softmax=False)
+                mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False, idx_ulb=idx_ulb)
+                # Keep only pseudo-labels whose reward is at least the batch-average reward
+                reward = rewarder(feats_x_ulb_w, pseudo_label)
+                avg_reward = reward.mean()
+                mask2 = torch.where(reward >= avg_reward, torch.tensor(1).cuda(gpu), torch.tensor(0).cuda(gpu)).squeeze().float()
+                unsup_loss = self.consistency_loss(logits_x_ulb_s, pseudo_label, 'ce', mask=mask, mask2=mask2)
+        return unsup_loss
+
+    def train_step(self, x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s):
+        num_lb = y_lb.shape[0]
+
+        # inference and calculate sup/unsup losses
+        with self.amp_cm():
+            if self.use_cat:
+                inputs = torch.cat((x_lb, x_ulb_w, x_ulb_s))
+                outputs = self.model(inputs)
+                logits_x_lb = outputs['logits'][:num_lb]
+                logits_x_ulb_w, logits_x_ulb_s = outputs['logits'][num_lb:].chunk(2)
+                feats_x_lb = outputs['feat'][:num_lb]
+                feats_x_ulb_w, feats_x_ulb_s = outputs['feat'][num_lb:].chunk(2)
+            else:
+                outs_x_lb = self.model(x_lb)
+                logits_x_lb = outs_x_lb['logits']
+                feats_x_lb = outs_x_lb['feat']
+                outs_x_ulb_s = self.model(x_ulb_s)
+                logits_x_ulb_s = outs_x_ulb_s['logits']
+                feats_x_ulb_s = outs_x_ulb_s['feat']
+                with torch.no_grad():
+                    outs_x_ulb_w = self.model(x_ulb_w)
+                    logits_x_ulb_w = outs_x_ulb_w['logits']
+                    feats_x_ulb_w = outs_x_ulb_w['feat']
+            feat_dict = {'x_lb': feats_x_lb, 'x_ulb_w': feats_x_ulb_w, 'x_ulb_s': feats_x_ulb_s}
+
+            sup_loss = self.ce_loss(logits_x_lb, y_lb, reduction='mean')
+
+            # probs_x_ulb_w = torch.softmax(logits_x_ulb_w, dim=-1)
+            probs_x_ulb_w = self.compute_prob(logits_x_ulb_w.detach())
+
+            # if distribution alignment hook is registered, call it
+            # this is implemented for imbalanced algorithm - CReST
+            if self.registered_hook("DistAlignHook"):
+                probs_x_ulb_w = self.call_hook("dist_align", "DistAlignHook", probs_x_ulb=probs_x_ulb_w.detach())
+            mask = self.call_hook("masking", "MaskingHook", logits_x_ulb=probs_x_ulb_w, softmax_x_ulb=False, idx_ulb=idx_ulb)
+            pseudo_label = self.call_hook("gen_ulb_targets", "PseudoLabelingHook",
+                                          logits=probs_x_ulb_w,
+                                          use_hard_label=self.use_hard_label,
+                                          T=self.T,
+                                          softmax=False)
+            if self.it > self.start_timing:
+                # After warm-up, re-score pseudo-labels with the rewarder inside data_generator
+                rewarder = self.rewarder
+                unsup_loss = self.data_generator(x_lb, y_lb, idx_ulb, x_ulb_w, x_ulb_s, rewarder, self.gpu)
+            else:
+                unsup_loss = self.consistency_loss(logits_x_ulb_s, pseudo_label, 'ce', mask=mask)
+
+            if self.it > 0:
+                # Generate fake labels from labeled features and train the rewarder/generator
+                self.rewarder.train()
+                self.generator.train()
+                generated_label = self.generator(feats_x_lb.detach())
+                generated_label = generated_label.long()
+                # Ground-truth labels for the labeled batch
+                real_labels_tensor = y_lb.cuda(self.gpu)
+                reward = self.rewarder(feats_x_lb.detach(), generated_label.squeeze(1))
+                if self.it >= self.start_timing:
+                    filtered_pseudo_labels = pseudo_label.long()
+                    filtered_feats_x_ulb_w = feats_x_ulb_w.detach()
+                    rewarder = self.rewarder.eval()
+
+                    reward = self.rewarder(feats_x_ulb_w.detach(), pseudo_label.long())
+                    reward = reward.mean()
+                    self.max_reward = torch.where(reward > self.max_reward, reward, self.max_reward)
+                    filtered_pseudo_labels = torch.where(reward > self.max_reward, pseudo_label.detach(), filtered_pseudo_labels)
+                    filtered_feats_x_ulb_w = torch.where(reward > self.max_reward, feats_x_ulb_w.detach(), filtered_feats_x_ulb_w)
+                    if self.it % self.N_k == 0 and self.it > self.start_timing:
+                        self.max_reward = -float('inf')
+                        self.rewarder.train()
+                        self.generator.train()
+                        generated_label = self.generator(filtered_feats_x_ulb_w.squeeze(1))
+                        generated_label = generated_label.long()
+                        reward = self.rewarder(filtered_feats_x_ulb_w, generated_label.squeeze(1))
+                        generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
+                        filtered_pseudo_labels = F.one_hot(filtered_pseudo_labels.long(), num_classes=self.num_classes)
+                        cosine_similarity_score = cosine_similarity_n(generated_label.float(), filtered_pseudo_labels.float())
+                        generator_loss = self.criterion(reward, torch.ones_like(reward).cuda(self.gpu))
+                        rewarder_loss = self.criterion(reward, cosine_similarity_score)
+
+                        self.generator_optimizer.zero_grad()
+                        self.rewarder_optimizer.zero_grad()
+
+                        generator_loss.backward(retain_graph=True)
+                        rewarder_loss.backward(retain_graph=True)
+
+                        self.generator_optimizer.step()
+                        self.rewarder_optimizer.step()
+                else:
+                    generated_label = F.one_hot(generated_label.squeeze(1), num_classes=self.num_classes)
+                    real_labels_tensor = F.one_hot(real_labels_tensor, num_classes=self.num_classes)
+                    cosine_similarity_score = cosine_similarity_n(generated_label.float(), real_labels_tensor.float())
+                    generator_loss = self.criterion(reward, torch.ones_like(reward).cuda(self.gpu))
+                    rewarder_loss = self.criterion(reward, cosine_similarity_score)
+
+                    self.generator_optimizer.zero_grad()
+                    self.rewarder_optimizer.zero_grad()
+
+                    generator_loss.backward(retain_graph=True)
+                    rewarder_loss.backward(retain_graph=True)
+
+                    self.generator_optimizer.step()
+                    self.rewarder_optimizer.step()
+
+            total_loss = sup_loss + self.lambda_u * unsup_loss
+
+        out_dict = self.process_out_dict(loss=total_loss, feat=feat_dict)
+        log_dict = self.process_log_dict(sup_loss=sup_loss.item(),
+                                         unsup_loss=unsup_loss.item(),
+                                         total_loss=total_loss.item(),
+                                         util_ratio=mask.float().mean().item())
+        return out_dict, log_dict
+
+    def get_save_dict(self):
+        save_dict = super().get_save_dict()
+        # additional saving arguments
+        save_dict['classwise_acc'] = self.hooks_dict['MaskingHook'].classwise_acc.cpu()
+        save_dict['selected_label'] = self.hooks_dict['MaskingHook'].selected_label.cpu()
+        return save_dict
+
+    def load_model(self, load_path):
+        checkpoint = super().load_model(load_path)
+        self.hooks_dict['MaskingHook'].classwise_acc = checkpoint['classwise_acc'].cuda(self.gpu)
+        self.hooks_dict['MaskingHook'].selected_label = checkpoint['selected_label'].cuda(self.gpu)
+        self.print_fn("additional parameter loaded")
+        return checkpoint
+
+    @staticmethod
+    def get_argument():
+        return [
+            SSL_Argument('--hard_label', str2bool, True),
+            SSL_Argument('--T', float, 0.5),
+            SSL_Argument('--p_cutoff', float, 0.95),
+            SSL_Argument('--thresh_warmup', str2bool, True),
+            SSL_Argument('--start_timing', int, 20000),
+            SSL_Argument('--feature_dim', int, 384),
+            SSL_Argument('--sr_lr', float, 0.0005),
+            SSL_Argument('--N_k', int, 10),
+            SSL_Argument('--sr_ema', str2bool, True),
+            SSL_Argument('--sr_ema_m', float, 0.999),
+        ]
diff --git a/semilearn/algorithms/srflexmatch/utils.py b/semilearn/algorithms/srflexmatch/utils.py
new file mode 100644
index 000000000..2c32d1498
--- /dev/null
+++ b/semilearn/algorithms/srflexmatch/utils.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+from copy import deepcopy
+from collections import Counter
+
+from semilearn.algorithms.hooks import MaskingHook
+
+
+class FlexMatchThresholdingHook(MaskingHook):
+    """
+    Adaptive Thresholding in FlexMatch
+    """
+    def __init__(self, ulb_dest_len, num_classes, thresh_warmup=True, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ulb_dest_len = ulb_dest_len
+        self.num_classes = num_classes
+        self.thresh_warmup = thresh_warmup
+        self.selected_label = torch.ones((self.ulb_dest_len,), dtype=torch.long) * -1
+        self.classwise_acc = torch.zeros((self.num_classes,))
+
+    @torch.no_grad()
+    def update(self, *args, **kwargs):
+        pseudo_counter = Counter(self.selected_label.tolist())
+        if max(pseudo_counter.values()) < self.ulb_dest_len:  # not all unlabeled samples are still unselected (-1)
+            if self.thresh_warmup:
+                for i in range(self.num_classes):
+                    self.classwise_acc[i] = pseudo_counter[i] / max(pseudo_counter.values())
+            else:
+                wo_negative_one = deepcopy(pseudo_counter)
+                if -1 in wo_negative_one.keys():
+                    wo_negative_one.pop(-1)
+                for i in range(self.num_classes):
+                    self.classwise_acc[i] = pseudo_counter[i] / max(wo_negative_one.values())
+
+    @torch.no_grad()
+    def masking(self, algorithm, logits_x_ulb, idx_ulb, softmax_x_ulb=True, *args, **kwargs):
+        if not self.selected_label.is_cuda:
+            self.selected_label = self.selected_label.to(logits_x_ulb.device)
+        if not self.classwise_acc.is_cuda:
+            self.classwise_acc = self.classwise_acc.to(logits_x_ulb.device)
+
+        if softmax_x_ulb:
+            # probs_x_ulb = torch.softmax(logits_x_ulb.detach(), dim=-1)
+            probs_x_ulb = self.compute_prob(logits_x_ulb.detach())
+        else:
+            # logits are already probabilities
+            probs_x_ulb = logits_x_ulb.detach()
+        max_probs, max_idx = torch.max(probs_x_ulb, dim=-1)
+        # mask = max_probs.ge(p_cutoff * (class_acc[max_idx] + 1.) / 2).float()  # linear
+        # mask = max_probs.ge(p_cutoff * (1 / (2. - class_acc[max_idx]))).float()  # low_limit
+        mask = max_probs.ge(algorithm.p_cutoff * (self.classwise_acc[max_idx] / (2. - self.classwise_acc[max_idx])))  # convex
+        # mask = max_probs.ge(p_cutoff * (torch.log(class_acc[max_idx] + 1.) + 0.5)/(math.log(2) + 0.5)).float()  # concave
+        select = max_probs.ge(algorithm.p_cutoff)
+        mask = mask.to(max_probs.dtype)
+
+        # update
+        if idx_ulb[select == 1].nelement() != 0:
+            self.selected_label[idx_ulb[select == 1]] = max_idx[select == 1]
+        self.update()
+
+        return mask
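Below is a minimal usage sketch of the new SemiReward pieces (illustrative only, not part of the diff). It assumes feature_dim=384, a 10-class task, and random tensors standing in for real backbone features and pseudo-labels; the reward-based mask mirrors the filtering in SRFlexMatch.data_generator, and the cosine-similarity target mirrors the rewarder/generator update in train_step (the clamp is added here only to keep the toy generator's label indices in range).

import torch
import torch.nn.functional as F

from semilearn.algorithms.semireward import (Generator, Rewarder,
                                             cosine_similarity_n, label_dim)

num_classes, feature_dim, batch = 10, 384, 8
rewarder = Rewarder(label_dim(num_classes), 128, feature_dim)  # label_dim() pads small class counts up to 100
generator = Generator(feature_dim)

feats = torch.randn(batch, feature_dim)                  # stand-in for backbone features
pseudo_labels = torch.randint(0, num_classes, (batch,))  # stand-in for hard pseudo-labels

# Score each (feature, pseudo-label) pair; rewards lie in (0, 1)
reward = rewarder(feats, pseudo_labels)                  # shape [batch, 1]
mask = (reward >= reward.mean()).squeeze(1).float()      # keep samples scored above the batch average

# Rewarder regression target: cosine similarity between one-hot generated labels
# and one-hot pseudo-labels, rescaled to [0, 1] by cosine_similarity_n
gen_labels = generator(feats).long().clamp_(0, num_classes - 1).squeeze(1)
target = cosine_similarity_n(F.one_hot(gen_labels, num_classes).float(),
                             F.one_hot(pseudo_labels, num_classes).float())
rewarder_loss = torch.nn.MSELoss()(rewarder(feats, gen_labels), target)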