The corrected code is as follows (correctness not guaranteed; I will attach screenshots of the results):

```python
r"""
CT_NCM
##################################
Reference:
Haiping Ma et al. "Reconciling cognitive modeling with knowledge forgetting: A continuous time-aware neural network approach." in IJCAI 2022.
Reference code: https://github.com/BIMK/Intelligent-Education/tree/main/CTNCM
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..gd_basemodel import GDBaseModel
class CT_NCM(GDBaseModel):
"""
hidden_size: dimensions of LSTM hidden layers
embed_size: dimensions of student-knowledge concept interaction embedding
prelen1: the first layer of performance prediction
prelen2: the second layer of performance prediction
dropout1: the proportion of first fully connected layer dropout before getting the prediction score
dropout2: the proportion of second fully connected layer dropout before getting the prediction score
"""
default_cfg = {
'hidden_size': 64,
'embed_size': 64,
'prelen1': 256,
'prelen2': 128,
'dropout1': 0,
'dropout2': 0,
}
def __init__(self, cfg):
"""Pass parameters from other templates into the model
Args:
cfg (UnifyConfig): parameters from other templates
"""
super().__init__(cfg)
def build_cfg(self):
"""Initialize the parameters of the model"""
self.problem_num = self.datatpl_cfg['dt_info']['exer_count']
self.skill_num = self.datatpl_cfg['dt_info']['cpt_count']
self.device = self.traintpl_cfg['device']
self.hidden_size = self.modeltpl_cfg['hidden_size']
self.embed_size = self.modeltpl_cfg['embed_size']
self.knowledge_dim = self.hidden_size
self.input_len = self.knowledge_dim
self.prelen1 = self.modeltpl_cfg['prelen1']
self.prelen2 = self.modeltpl_cfg['prelen2']
        # Use BCEWithLogitsLoss, because the model output is not passed through a sigmoid activation
        self.loss_function = torch.nn.BCEWithLogitsLoss()
        # To use BCELoss instead, comment out the line above and enable the line below
        # self.loss_function = torch.nn.BCELoss()
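        # (BCEWithLogitsLoss fuses the sigmoid into the loss via the
        # log-sum-exp trick, so raw logits of any magnitude are handled
        # stably; BCELoss requires inputs already in [0, 1].)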
def build_model(self):
"""Initialize the various components of the model"""
self.dropout1 = nn.Dropout(p=self.modeltpl_cfg['dropout1'])
self.dropout2 = nn.Dropout(p=self.modeltpl_cfg['dropout2'])
self.inter_embedding = torch.nn.Embedding(2 * self.skill_num, self.embed_size)
self.reclstm = torch.nn.Linear(self.embed_size + self.hidden_size, 7 * self.hidden_size)
self.problem_disc = torch.nn.Embedding(self.problem_num, 1)
self.problem_diff = torch.nn.Embedding(self.problem_num, self.knowledge_dim)
self.linear1 = torch.nn.Linear(self.input_len, self.prelen1)
self.linear2 = torch.nn.Linear(self.prelen1, self.prelen2)
self.linear3 = torch.nn.Linear(self.prelen2, 1)
        # Manually initialize the weights (Xavier normal)
nn.init.xavier_normal_(self.reclstm.weight)
nn.init.xavier_normal_(self.linear1.weight)
nn.init.xavier_normal_(self.linear2.weight)
nn.init.xavier_normal_(self.linear3.weight)
def forward(self, exer_seq, start_timestamp_seq, cpt_unfold_seq, label_seq, mask_seq, **kwargs):
"""A function of how well the model predicts students' responses to exercise questions
Args:
exer_seq (torch.Tensor): Sequence of exercise id. Shape of [batch_size, seq_len]
start_timestamp_seq (torch.Tensor): Sequence of Students start answering time. Shape of [batch_size, seq_len]
cpt_unfold_seq (torch.Tensor): Sequence of knowledge concepts related to exercises. Shape of [batch_size, seq_len]
label_seq (torch.Tensor): Sequence of students' answers to exercises. Shape of [batch_size, seq_len]
mask_seq (torch.Tensor): Sequence of mask. Mask=1 indicates that the student has answered the exercise, otherwise vice versa. Shape of [batch_size, seq_len]
Returns:
dict: The predictions of the model and the real situation
"""
        # Check the inputs for NaN/Inf before any processing
        for name, t in (('exer_seq', exer_seq),
                        ('start_timestamp_seq', start_timestamp_seq),
                        ('cpt_unfold_seq', cpt_unfold_seq),
                        ('label_seq', label_seq),
                        ('mask_seq', mask_seq)):
            assert not torch.isnan(t).any(), f"{name} contains NaNs"
            assert not torch.isinf(t).any(), f"{name} contains Inf"
problem_seqs_tensor = exer_seq[:, 1:].to(self.device)
skill_seqs_tensor = cpt_unfold_seq.to(self.device)
start_timestamp_seqs_tensor = start_timestamp_seq[:, 1:].to(self.device)
correct_seqs_tensor = label_seq.to(self.device)
mask_labels = mask_seq.long().to(self.device)
seqs_length = torch.sum(mask_labels, dim=1)
delete_row = 0
for i in range(len(seqs_length)):
if seqs_length[i] == 1:
mask = torch.arange(problem_seqs_tensor.size(0)) != (i - delete_row)
problem_seqs_tensor = problem_seqs_tensor[mask]
skill_seqs_tensor = skill_seqs_tensor[mask]
start_timestamp_seqs_tensor = start_timestamp_seqs_tensor[mask]
correct_seqs_tensor = correct_seqs_tensor[mask]
mask_labels = mask_labels[mask]
delete_row += 1
        # Where mask_labels == 0, set the labels to 0 instead of -1
correct_seqs_tensor = torch.where(mask_labels == 0, 0, correct_seqs_tensor)
skill_seqs_tensor = torch.where(mask_labels == 0, 0, skill_seqs_tensor)
mask_labels_temp = mask_labels[:, 1:]
start_timestamp_seqs_tensor = torch.where(mask_labels_temp == 0, 0, start_timestamp_seqs_tensor)
problem_seqs_tensor = torch.where(mask_labels_temp == 0, 0, problem_seqs_tensor)
seqs_length = torch.sum(mask_labels, dim=1)
        # Re-check the data after preprocessing
assert not torch.isnan(problem_seqs_tensor).any(), "problem_seqs_tensor contains NaNs after processing"
assert not torch.isnan(skill_seqs_tensor).any(), "skill_seqs_tensor contains NaNs after processing"
assert not torch.isnan(start_timestamp_seqs_tensor).any(), "start_timestamp_seqs_tensor contains NaNs after processing"
assert not torch.isnan(correct_seqs_tensor).any(), "correct_seqs_tensor contains NaNs after processing"
inter_embed_tensor = self.inter_embedding(skill_seqs_tensor + self.skill_num * mask_labels)
batch_size = correct_seqs_tensor.size()[0]
hidden, _ = self.continues_lstm(inter_embed_tensor, start_timestamp_seqs_tensor, seqs_length, batch_size)
hidden_packed = torch.nn.utils.rnn.pack_padded_sequence(hidden[1:, ],
seqs_length.cpu() - 1,
batch_first=False,
enforce_sorted=False)
theta = hidden_packed.data
problem_packed = torch.nn.utils.rnn.pack_padded_sequence(problem_seqs_tensor,
seqs_length.cpu() - 1,
batch_first=True,
enforce_sorted=False)
predictions = torch.squeeze(self.problem_hidden(theta, problem_packed.data))
labels_packed = torch.nn.utils.rnn.pack_padded_sequence(correct_seqs_tensor[:, 1:],
seqs_length.cpu() - 1,
batch_first=True,
enforce_sorted=False)
labels = labels_packed.data
out_dict = {'predictions': predictions, 'labels': labels}
        # Check the model outputs for NaN/Inf
        assert not torch.isnan(predictions).any(), "Predictions contain NaNs"
        assert not torch.isnan(labels).any(), "Labels contain NaNs"
        assert not torch.isinf(predictions).any(), "Predictions contain Inf"
        assert not torch.isinf(labels).any(), "Labels contain Inf"
return out_dict
def continues_lstm(self, inter_embed_tensor, start_timestamp_seqs_tensor, seqs_length, batch_size):
"""
Args:
            inter_embed_tensor (torch.Tensor): Embedded student-concept interactions. Shape of [batch_size, seq_len, embed_size]
            start_timestamp_seqs_tensor (torch.Tensor): Sequence of students' answer-start times. Shape of [batch_size, seq_len-1]
            seqs_length (torch.Tensor): Length of each sequence. Shape of [batch_size]
            batch_size (int): batch size.
        Returns:
            tuple: stacked hidden states of shape [max_seq_len, batch_size, hidden_size], and the sequence lengths
"""
self.init_states(batch_size=batch_size)
h_list = [self.h_delay]
for t in range(max(seqs_length) - 1):
one_batch = inter_embed_tensor[:, t]
c, self.c_bar, output_t, delay_t = self.conti_lstm(one_batch, self.h_delay, self.c_delay, self.c_bar)
time_lag_batch = start_timestamp_seqs_tensor[:, t]
self.c_delay, self.h_delay = self.delay(c, self.c_bar, output_t, delay_t, time_lag_batch)
            # Ensure h_delay is a float tensor on the device and contains no NaN/Inf
self.h_delay = torch.as_tensor(self.h_delay, dtype=torch.float).to(self.device)
assert not torch.isnan(self.h_delay).any(), f"h_delay at time {t} contains NaNs"
assert not torch.isinf(self.h_delay).any(), f"h_delay at time {t} contains Inf"
h_list.append(self.h_delay)
hidden = torch.stack(h_list)
return hidden, seqs_length
def init_states(self, batch_size):
"""Initialize the state of lstm
Args:
batch_size (int): batch_size
"""
self.h_delay = torch.full((batch_size, self.hidden_size), 0.5, dtype=torch.float).to(self.device)
self.c_delay = torch.full((batch_size, self.hidden_size), 0.5, dtype=torch.float).to(self.device)
self.c_bar = torch.full((batch_size, self.hidden_size), 0.5, dtype=torch.float).to(self.device)
self.c = torch.full((batch_size, self.hidden_size), 0.5, dtype=torch.float).to(self.device)
def conti_lstm(self, one_batch_inter_embed, h_d_t, c_d_t, c_bar_t):
"""
Args:
            one_batch_inter_embed (torch.Tensor): one time step of the embedded interactions. Shape of [batch_size, embed_size]
            h_d_t (torch.Tensor): decayed hidden state. Shape of [batch_size, hidden_size]
            c_d_t (torch.Tensor): decayed cell state. Shape of [batch_size, hidden_size]
            c_bar_t (torch.Tensor): steady-state cell. Shape of [batch_size, hidden_size]
        Returns:
            tuple: updated cell state, steady-state cell, output gate, and decay rate
"""
input = torch.cat((one_batch_inter_embed, h_d_t), dim=1)
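        # A single linear map yields all seven gate pre-activations: the
        # standard LSTM gates (i, f, z, o), a second input/forget pair
        # (i_bar, f_bar) driving the steady-state cell c_bar, and the decay rate.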
(i, f, z, o, i_bar, f_bar, delay) = torch.chunk(self.reclstm(input), 7, -1)
i = torch.sigmoid(i)
f = torch.sigmoid(f)
z = torch.tanh(z)
o = torch.sigmoid(o)
i_bar = torch.sigmoid(i_bar)
f_bar = torch.sigmoid(f_bar)
delay = F.softplus(delay)
        # Check the gate activations for NaN/Inf
        for name, g in (('i', i), ('f', f), ('z', z), ('o', o),
                        ('i_bar', i_bar), ('f_bar', f_bar), ('delay', delay)):
            assert not torch.isnan(g).any(), f"{name} contains NaNs"
            assert not torch.isinf(g).any(), f"{name} contains Inf"
c_t = f * c_d_t + i * z
c_bar_t = f_bar * c_bar_t + i_bar * z
        # Check c_t and c_bar_t for NaN/Inf
assert not torch.isnan(c_t).any(), "c_t contains NaNs"
assert not torch.isnan(c_bar_t).any(), "c_bar_t contains NaNs"
assert not torch.isinf(c_t).any(), "c_t contains Inf"
assert not torch.isinf(c_bar_t).any(), "c_bar_t contains Inf"
return c_t, c_bar_t, o, delay
def delay(self, c, c_bar, output, delay, time_lag):
"""
Args:
            c (torch.Tensor): Shape of [batch_size, hidden_size]
            c_bar (torch.Tensor): Shape of [batch_size, hidden_size]
            output (torch.Tensor): Shape of [batch_size, hidden_size]
            delay (torch.Tensor): Shape of [batch_size, hidden_size]
            time_lag (torch.Tensor): Shape of [batch_size]
        Returns:
            tuple: decayed cell state and decayed hidden state
"""
exponent = - delay * time_lag.unsqueeze(-1)
        exponent = torch.clamp(exponent, min=-20, max=10)  # clamp the exponent range; max lowered to 10
delta = c - c_bar
        delta = torch.clamp(delta, min=-10, max=10)  # clamp the range of delta
c_delay = c_bar + delta * torch.exp(exponent)
h_delay = output * torch.tanh(c_delay)
        # Check c_delay and h_delay for NaN/Inf
assert not torch.isnan(c_delay).any(), "c_delay contains NaNs"
assert not torch.isnan(h_delay).any(), "h_delay contains NaNs"
assert not torch.isinf(c_delay).any(), "c_delay contains Inf"
assert not torch.isinf(h_delay).any(), "h_delay contains Inf"
return c_delay, h_delay
def problem_hidden(self, theta, problem_data):
"""Get how well the model predicts students' responses to exercise questions
Args:
            theta (torch.Tensor): Student ability state for each valid interaction. Shape of [n_interactions, hidden_size]
            problem_data (torch.Tensor): The ids of the exercises answered at those interactions. Shape of [n_interactions]
        Returns:
            torch.Tensor: the model's predicted logits for students' responses. Shape of [n_interactions, 1]
"""
problem_diff = torch.sigmoid(self.problem_diff(problem_data))
problem_disc = torch.sigmoid(self.problem_disc(problem_data))
        input_x = (theta - problem_diff) * problem_disc  # the original `* 10` scaling is removed
        # Clamp to a safe range (adjust to the data if necessary)
        input_x = torch.clamp(input_x, min=-10, max=10)
        # Check input_x for NaN/Inf
assert not torch.isnan(input_x).any(), "input_x contains NaNs"
assert not torch.isinf(input_x).any(), "input_x contains Inf"
        # Use ReLU instead of sigmoid
input_x = self.dropout1(F.relu(self.linear1(input_x)))
        # Check input_x for NaN/Inf after linear1 + ReLU
assert not torch.isnan(input_x).any(), "input_x after linear1 and ReLU contains NaNs"
assert not torch.isinf(input_x).any(), "input_x after linear1 and ReLU contains Inf"
        # Use ReLU instead of sigmoid
input_x = self.dropout2(F.relu(self.linear2(input_x)))
        # Check input_x for NaN/Inf after linear2 + ReLU
assert not torch.isnan(input_x).any(), "input_x after linear2 and ReLU contains NaNs"
assert not torch.isinf(input_x).any(), "input_x after linear2 and ReLU contains Inf"
        output = self.linear3(input_x)  # no sigmoid here; raw logits are returned
        # Check the output for NaN/Inf
assert not torch.isnan(output).any(), "output contains NaNs"
assert not torch.isinf(output).any(), "output contains Inf"
        # Optional: print summary statistics of the output
        print(f"Output - min: {output.min().item()}, max: {output.max().item()}")
return output
def predict(self, **kwargs):
"""A function of get how well the model predicts students' responses to exercise questions and the groundtruth
Returns:
dict: The predictions of the model and the real situation
"""
outdict = self(**kwargs)
        # Apply sigmoid at prediction time to obtain probability values
y_pd = torch.sigmoid(outdict['predictions'])
return {
'y_pd': y_pd,
'y_gt': torch.as_tensor(outdict['labels'], dtype=torch.float)
}
def get_main_loss(self, **kwargs):
"""
Returns:
dict: loss dict{'loss_main': loss_value}
"""
outdict = self(**kwargs)
predictions = outdict['predictions']
labels = outdict['labels']
labels = torch.as_tensor(labels, dtype=torch.float)
        # Assert that every label is 0 or 1
        assert torch.all((labels == 0) | (labels == 1)), "labels contain values other than 0 or 1"
        # Check predictions for NaN/Inf
        assert not torch.isnan(predictions).any(), "Predictions contain NaNs"
        assert not torch.isinf(predictions).any(), "Predictions contain Inf"
loss = self.loss_function(predictions, labels)
return {
'loss_main': loss
}
def get_loss_dict(self, **kwargs):
"""
Returns:
dict: loss dict{'loss_main': loss_value}
"""
        return self.get_main_loss(**kwargs)
```
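For context, a hypothetical driver snippet (not part of EduStudio; `model` and `batch_dict` are assumed to exist) showing how the two heads are consumed: the training path works on raw logits, while predict applies the sigmoid explicitly.

```python
import torch

# Hypothetical usage: `model` is a CT_NCM instance and `batch_dict` a batch
# from the EduStudio data loader.
loss_dict = model.get_loss_dict(**batch_dict)   # raw logits -> BCEWithLogitsLoss
loss_dict['loss_main'].backward()

with torch.no_grad():
    out = model.predict(**batch_dict)           # sigmoid(logits) -> probabilities
    assert out['y_pd'].min() >= 0.0 and out['y_pd'].max() <= 1.0
```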
🎄 Thank you for your advice about EduStudio.
We carefully reviewed the original implementation code and the implementations in other repositories. We conclude that your suggestion is feasible, and the test results we obtained are similar to the ones you provided. Thus, we have updated the implementation of SAKT with your suggestion.
Note: because of the multiple nn.Dropout layers, re-running the training can sometimes avoid the overflow issue.
Thank you again for your valuable advice, and we wish you good luck.
Both models were run with the run scripts under examples. This issue reports the errors from both models: both are CUDA assertion failures, and both occur in the ModelTPL component, but they fail in different ways.
SAKT error:
2024-12-17 23:02:29[INFO]: actual window size: 100
2024-12-17 23:02:29[INFO]: {'stu_count': 4163, 'exer_count': 153, 'real_window_size': 100}
2024-12-17 23:02:29[INFO]: TrainTPL <class 'edustudio.traintpl.general_traintpl.GeneralTrainTPL'> Started!
2024-12-17 23:02:29[INFO]: ====== [FOLD ID]: 0 ======
2024-12-17 23:02:29[INFO]: [CALLBACK]-ModelCheckPoint has been registered!
2024-12-17 23:02:29[INFO]: [CALLBACK]-EarlyStopping has been registered!
2024-12-17 23:02:29[INFO]: [CALLBACK]-History has been registered!
2024-12-17 23:02:29[INFO]: [CALLBACK]-BaseLogger has been registered!
2024-12-17 23:02:29[INFO]: Start Training...
[EPOCH=001]: 0%| | 0/123 [00:00<?, ?it/s]C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [0,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [1,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [2,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [3,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [4,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [5,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [6,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [7,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [8,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [9,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [10,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [11,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [12,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [13,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [14,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [15,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [16,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [3,0,0], thread: [17,0,0] Assertion `input_val >= zero && input_val <= one` failed.
[EPOCH=001]: 0%| | 0/123 [00:01<?, ?it/s]
Traceback (most recent call last):
File "C:/Users/Akkzzzz/Desktop/EDU/EduStudio-main/examples/single_model/run_sakt_demo.py", line 9, in
run_edustudio(
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\quickstart\quickstart.py", line 72, in run_edustudio
raise e
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\quickstart\quickstart.py", line 58, in run_edustudio
traintpl.start()
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\traintpl\gd_traintpl.py", line 79, in start
metrics = self.one_fold_start(fold_id)
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\traintpl\general_traintpl.py", line 53, in one_fold_start
self.fit(train_loader=self.train_loader, valid_loader=self.valid_loader)
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\traintpl\general_traintpl.py", line 99, in fit
loss.backward()
File "D:\ANAA\envs\EduStudio\lib\site-packages\torch_tensor.py", line 487, in backward
torch.autograd.backward(
File "D:\ANAA\envs\EduStudio\lib\site-packages\torch\autograd_init_.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
The input to binary_cross_entropy (y_pd) falls outside its expected [0, 1] range. Tracing the pipeline: since TrainTPL only manages the training loop, the problem must lie in the ModelTPL component, so I opened sakt.py.
To fix it, I removed the sigmoid activation from the forward method and used BCEWithLogitsLoss in the loss computation instead of binary_cross_entropy; BCEWithLogitsLoss applies the sigmoid internally, which improves numerical stability.
The corrected code follows (correctness not guaranteed; I will attach screenshots of the results after the fix).
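In essence, the change is the following pattern (a minimal sketch with illustrative names, not the full sakt.py):

```python
import torch
import torch.nn as nn

class SAKTOutputHead(nn.Module):
    """Illustrative prediction head only; the attention blocks are omitted."""
    def __init__(self, d_model: int):
        super().__init__()
        self.out = nn.Linear(d_model, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Before: return torch.sigmoid(self.out(h)).squeeze(-1)
        # After:  return raw logits; the sigmoid moves into the loss
        return self.out(h).squeeze(-1)

# Loss computation:
# Before: loss = F.binary_cross_entropy(y_pd, y_gt)  # y_pd must lie in [0, 1]
# After:  BCEWithLogitsLoss applies the sigmoid internally
loss_fn = nn.BCEWithLogitsLoss()
# loss = loss_fn(logits, y_gt)
# At evaluation time, recover probabilities with torch.sigmoid(logits).
```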
CT_NCM error:
2024-12-17 19:29:24[INFO]: actual window size: 100
2024-12-17 19:29:26[INFO]: {'real_window_size': 100, 'stu_count': 4163, 'exer_count': 16988, 'cpt_count': 122}
2024-12-17 19:29:26[INFO]: TrainTPL <class 'edustudio.traintpl.general_traintpl.GeneralTrainTPL'> Started!
2024-12-17 19:29:26[INFO]: ====== [FOLD ID]: 0 ======
[EPOCH=001]: 0%| | 0/133 [00:00<?, ?it/s]2024-12-17 19:29:26[INFO]: [CALLBACK]-ModelCheckPoint has been registered!
2024-12-17 19:29:26[INFO]: [CALLBACK]-EarlyStopping has been registered!
2024-12-17 19:29:26[INFO]: [CALLBACK]-History has been registered!
2024-12-17 19:29:26[INFO]: [CALLBACK]-BaseLogger has been registered!
2024-12-17 19:29:26[INFO]: Start Training...
[EPOCH=001]: 93%|████████████████████████████████████████████▊ | 124/133 [01:27<00:05, 1.58it/s]C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [1,0,0], thread: [33,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [0,0,0], thread: [36,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [0,0,0], thread: [104,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [1,0,0], thread: [68,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [1,0,0], thread: [73,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [0,0,0], thread: [64,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [0,0,0], thread: [91,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [0,0,0], thread: [28,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [1,0,0], thread: [12,0,0] Assertion `input_val >= zero && input_val <= one` failed.
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\Loss.cu:103: block: [1,0,0], thread: [28,0,0] Assertion `input_val >= zero && input_val <= one` failed.
[EPOCH=001]: 93%|████████████████████████████████████████████▊ | 124/133 [01:29<00:06, 1.39it/s]
2024-12-17 19:30:55[ERROR]: Traceback (most recent call last):
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\quickstart\quickstart.py", line 58, in run_edustudio
traintpl.start()
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\traintpl\gd_traintpl.py", line 79, in start
metrics = self.one_fold_start(fold_id)
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\traintpl\general_traintpl.py", line 53, in one_fold_start
self.fit(train_loader=self.train_loader, valid_loader=self.valid_loader)
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\traintpl\general_traintpl.py", line 96, in fit
loss_dict = self.model.get_loss_dict(**batch_dict)
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\model\KT\ct_ncm.py", line 260, in get_loss_dict
return self.get_main_loss(**kwargs)
File "C:\Users\Akkzzzz\Desktop\EDU\EduStudio-main\edustudio\model\KT\ct_ncm.py", line 249, in get_main_loss
loss = self.loss_function(predictions, labels)
File "D:\ANAA\envs\EduStudio\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ANAA\envs\EduStudio\lib\site-packages\torch\nn\modules\loss.py", line 619, in forward
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
File "D:\ANAA\envs\EduStudio\lib\site-packages\torch\nn\functional.py", line 3098, in binary_cross_entropy
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
RuntimeError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
Error analysis: some of the values fed to binary_cross_entropy (input_val) fall outside the expected range [0, 1].
Opening the ModelTPL component shows that the forward method originally set masked labels to -1 via `correct_seqs_tensor = torch.where(mask_labels == 0, -1, correct_seqs_tensor)`, which is what makes the CUDA assertion fail when the loss is computed. After fixing that, a numerical-overflow problem appeared, so in the delay method I clamped the inputs to prevent overflow and, as with the SAKT fix above, switched the loss computation.
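The failure can be reproduced in isolation (a minimal sketch; on CPU the same range check raises an ordinary RuntimeError instead of a device-side assert):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.5, -1.0, 0.3])   # raw logits, not probabilities
targets = torch.tensor([1.0, 0.0, 1.0])

try:
    # binary_cross_entropy validates that its input lies in [0, 1];
    # raw logits violate that check, which on CUDA surfaces as the
    # device-side assert shown above.
    F.binary_cross_entropy(logits, targets)
except RuntimeError as err:
    print(err)  # "all elements of input should be between 0 and 1"

# binary_cross_entropy_with_logits accepts unbounded logits directly.
print(F.binary_cross_entropy_with_logits(logits, targets))
```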
The results after modifying CT_NCM are as follows (both runs use the ASSIST0910 dataset):
(screenshot attached)
The results after modifying SAKT are as follows:
(screenshot attached)