@@ -1033,6 +1033,74 @@ namespace deepx::tf
             return 0;
         }
     };
+
+    // dropout
+    template <typename Author>
+    class Dropout : public TF
+    {
+    public:
+        Dropout(const vector<Param> &args, const vector<Param> &returns)
+        {
+            this->name = "dropout";
+            this->metadata.author = Author::name();
+            this->tftype = "elementwise";
+            this->args = args;
+            this->returns = returns;
+        }
+        string math_formula() const override
+        {
+            return "T1.dropout(p,seed)->T3";
+        }
+        shared_ptr<TF> clone() const override
+        {
+            return make_shared<Dropout<Author>>(*this);
+        }
+        int run(shared_ptr<MemBase> mem, string &error) override
+        {
+            if (!checktensors({this->args[0].textvalue, this->returns[0].textvalue}, mem, error))
+            {
+                return 1;
+            }
+            Precision a_type = mem->gettensor(this->args[0].textvalue).get()->shape.dtype;
+            Precision c_type = mem->gettensor(this->returns[0].textvalue).get()->shape.dtype;
+            if (a_type != c_type)
+            {
+                error = "Type mismatch: " + precision_str(a_type) + " != " + precision_str(c_type);
+                return 1;
+            }
+            switch (a_type)
+            {
+            case Precision::Float64:
+                tensorfunc::dropout<Author>(*mem->gettensor<double>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<double>(this->returns[0].textvalue));
+                break;
+            case Precision::Float32:
+                tensorfunc::dropout<Author>(*mem->gettensor<float>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<float>(this->returns[0].textvalue));
+                break;
+            case Precision::Float16:
+                tensorfunc::dropout<Author>(*mem->gettensor<half>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<half>(this->returns[0].textvalue));
+                break;
+            case Precision::BFloat16:
+                tensorfunc::dropout<Author>(*mem->gettensor<nv_bfloat16>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<nv_bfloat16>(this->returns[0].textvalue));
+                break;
+            case Precision::Int64:
+                tensorfunc::dropout<Author>(*mem->gettensor<int64_t>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<int64_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int32:
+                tensorfunc::dropout<Author>(*mem->gettensor<int32_t>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<int32_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int16:
+                tensorfunc::dropout<Author>(*mem->gettensor<int16_t>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<int16_t>(this->returns[0].textvalue));
+                break;
+            case Precision::Int8:
+                tensorfunc::dropout<Author>(*mem->gettensor<int8_t>(this->args[0].textvalue), this->getvar<float>(1, mem), this->getvar<unsigned int>(2, mem), *mem->gettensor<int8_t>(this->returns[0].textvalue));
+                break;
+            default:
+                error = "Unsupported dtype: " + precision_str(a_type);
+                return 1;
+            }
+            return 0;
+        }
+    };
 };
 
 #endif // DEEPX_TF_ELEMENTWISE_BASIC_HPP
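
Note on the added TF: per the math_formula string "T1.dropout(p,seed)->T3" and the getvar calls in run(), args[0] names the input tensor, parameter slot 1 carries the dropout probability p as a float, slot 2 carries the RNG seed as an unsigned int, and returns[0] names the output tensor; the Precision switch then dispatches to tensorfunc::dropout<Author> for the matching element type, following the same pattern as the other elementwise TFs in this header. For reference, the standalone sketch below illustrates the element-wise semantics such a dropout kernel typically has. It is not the deepx implementation (tensorfunc::dropout lives elsewhere in the repo), and the inverted-dropout scaling of survivors by 1/(1-p) is an assumption about the convention used.

    // Standalone sketch (NOT the deepx tensorfunc::dropout): inverted dropout over a flat buffer.
    // Assumptions: p in [0, 1); dropped elements become 0; survivors are scaled by 1/(1-p);
    // std::mt19937 seeded with `seed` for reproducibility.
    #include <cstddef>
    #include <cstdio>
    #include <random>
    #include <vector>

    template <typename T>
    void dropout_sketch(const T *in, float p, unsigned int seed, T *out, std::size_t n)
    {
        std::mt19937 rng(seed);               // fixed seed => reproducible mask
        std::bernoulli_distribution drop(p);  // true => zero this element
        const float scale = (p < 1.0f) ? 1.0f / (1.0f - p) : 0.0f;
        for (std::size_t i = 0; i < n; ++i)
            out[i] = drop(rng) ? T(0) : T(in[i] * scale);
    }

    int main()
    {
        std::vector<float> x(8, 1.0f), y(8);
        dropout_sketch(x.data(), 0.5f, 42u, y.data(), x.size());
        for (float v : y)
            std::printf("%.2f ", v);          // roughly half zeros, the rest 2.00
        std::printf("\n");
    }

With p = 0.5 and a fixed seed, roughly half of the outputs come out as 0 and the rest as 2.00, and the same seed reproduces the same mask across runs, which is what the explicit seed parameter in the TF signature suggests.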