提交 b1da6442 作者: 杨锋

bug fixed

上级 d50d1fa4
......@@ -190,7 +190,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 36%|███▌ | 18117/50016 [00:04<00:07, 4136.26it/s]"
"100%|██████████| 50016/50016 [00:12<00:00, 4000.53it/s]\n"
]
}
],
......@@ -200,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"id": "829443c8-5ebc-4d47-bbca-5c554012bb70",
"metadata": {},
"outputs": [],
......@@ -210,7 +210,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "2458f247-ab5c-4ca5-b61e-1adb766a0ea5",
"metadata": {},
"outputs": [],
......@@ -220,7 +220,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "829a1b14-b8ad-4d5a-b5c9-40ac2e184162",
"metadata": {},
"outputs": [],
......@@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"id": "bfcdf3a6-177a-4e7a-b2c8-6d0f8222e42e",
"metadata": {},
"outputs": [],
......@@ -258,7 +258,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"id": "c8b89c38-bd51-40b0-a964-0df39d7c38cc",
"metadata": {},
"outputs": [],
......@@ -280,7 +280,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 19,
"id": "5b1ac171-8aaa-426d-b717-3bfbcade1564",
"metadata": {},
"outputs": [],
......@@ -296,7 +296,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"id": "95cc40b8-abf0-4d4c-9f8c-2948f06553b8",
"metadata": {
"tags": []
......@@ -328,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"id": "1a02bbcc-797c-4f50-880b-445837cf3f71",
"metadata": {},
"outputs": [],
......@@ -352,7 +352,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"id": "6707968e-e8af-41f5-bb8a-26b6eee1e576",
"metadata": {},
"outputs": [],
......@@ -376,7 +376,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"id": "411ff8c9-9c0d-46c6-98ac-7af10e0a9ddf",
"metadata": {},
"outputs": [],
......@@ -391,20 +391,20 @@
" \n",
" \n",
" self.convs_sequence = nn.ModuleList([\n",
" nn.Sequential(nn.Conv1d(in_channels=48, out_channels=24, kernel_size=h),\n",
" nn.BatchNorm1d(num_features=config.out_channels),\n",
" nn.Sigmoid(),\n",
" nn.MaxPool1d(kernel_size=config.feature_size - h)\n",
" nn.Sequential(nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1),\n",
" nn.BatchNorm1d(num_features=1),\n",
" nn.Sigmoid()\n",
" )\n",
" for h in range(5,100,5)])\n",
" for h in range(5,config.feature_size,5)])\n",
" self.convs_finterprint = nn.ModuleList([\n",
" nn.Sequential(nn.Conv1d(in_channels=1, out_channels=40, kernel_size=h),\n",
" \n",
" nn.BatchNorm1d(num_features=40),\n",
" nn.Sigmoid(),\n",
" nn.MaxPool1d(kernel_size=config.fingerprint_len - h)\n",
" nn.MaxPool1d(kernel_size=config.fingerprint_len - h),\n",
" )\n",
" for h in range(32,2048,32)])\n",
" self.fc = nn.Linear(in_features=2976, out_features=config.num_class)\n",
" for h in range(32,config.fingerprint_len,32)])\n",
" self.fc = nn.Linear(in_features=3700, out_features=config.num_class)\n",
" \n",
" def forward(self, x):\n",
" split_index = config.feature_size*config.in_channels\n",
......@@ -412,7 +412,6 @@
" fingerprint_products = x[:,split_index:].reshape(x.shape[0],1,config.fingerprint_len)\n",
" \n",
" \n",
" \n",
" out_encoded = [conv(encoded_products) for conv in self.convs_sequence]\n",
" out_encoded = torch.cat(out_encoded, dim=1)\n",
" out_encoded = out_encoded.view(-1, out_encoded.size(1)) \n",
......@@ -435,12 +434,40 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 24,
"id": "72ab74b2-630f-4161-a332-2245bf13e22e",
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/anaconda3/lib/python3.8/site-packages/torch/cuda/__init__.py:145: UserWarning: \n",
"NVIDIA GeForce RTX 3090 with CUDA capability sm_86 is not compatible with the current PyTorch installation.\n",
"The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.\n",
"If you want to use the NVIDIA GeForce RTX 3090 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/\n",
"\n",
" warnings.warn(incompatible_device_warn.format(device_name, capability, \" \".join(arch_list), device_name))\n",
" 0%| | 0/100 [00:00<?, ?it/s]\n"
]
},
{
"ename": "RuntimeError",
"evalue": "shape '[128, 1, 2048]' is invalid for input of size 159744",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-24-09e96cf4d6c3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0mtrain_accuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mval_accuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_dataloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-20-7ba44751ddec>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataloader, model, optimizer, loss_fn)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# Compute prediction error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;31m# print(pred)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;31m# print(y)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1108\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1111\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1112\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-23-e10e7a067e38>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0msplit_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_size\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0min_channels\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mencoded_products\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0msplit_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0min_channels\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mfingerprint_products\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msplit_index\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfingerprint_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: shape '[128, 1, 2048]' is invalid for input of size 159744"
]
}
],
"source": [
"result_list = pd.DataFrame()\n",
"\n",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论