1 | # coding=utf-8 |
1 | import tensorflow as tf |
1 | # 读取数据并处理(测试集验证集) |
1 | # 定义模型参数(一层隐藏层) |
1 | # 定义激活函数 ReLu softmax |
1 | # dropout(H, drop_prob) |
1 | # 定义损失函数 交叉熵 L2正则项 |
1 | # 定义get_K_fold_data函数 |
1 | # 定义训练函数 |
1 | # 预测函数 predict() |
1 | params = [W1, b1, W2, b2] |
1 | result = predict(net, params, X_test) |