Skip to content

Commit 0fe099a

Browse files
authored
python pyi print "import datetime" for Duration/Timestamp field (#21885)
PiperOrigin-RevId: 761639872
1 parent f156008 commit 0fe099a

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed

src/google/protobuf/compiler/python/pyi_generator.cc

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ struct ImportModules {
7373
bool has_union = false; // typing.Union
7474
bool has_callable = false; // typing.Callable
7575
bool has_well_known_type = false;
76+
bool has_datetime = false;
7677
};
7778

7879
// Checks whether a descriptor name matches a well-known type.
@@ -112,8 +113,15 @@ void CheckImportModules(const Descriptor* descriptor,
112113
if (field->is_map()) {
113114
import_modules->has_mapping = true;
114115
const FieldDescriptor* value_des = field->message_type()->field(1);
115-
if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
116-
value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
116+
if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
117+
import_modules->has_union = true;
118+
const absl::string_view name = value_des->message_type()->full_name();
119+
if (name == "google.protobuf.Duration" ||
120+
name == "google.protobuf.Timestamp") {
121+
import_modules->has_datetime = true;
122+
}
123+
}
124+
if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
117125
import_modules->has_union = true;
118126
}
119127
} else {
@@ -123,6 +131,11 @@ void CheckImportModules(const Descriptor* descriptor,
123131
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
124132
import_modules->has_union = true;
125133
import_modules->has_mapping = true;
134+
const absl::string_view name = field->message_type()->full_name();
135+
if (name == "google.protobuf.Duration" ||
136+
name == "google.protobuf.Timestamp") {
137+
import_modules->has_datetime = true;
138+
}
126139
}
127140
if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
128141
import_modules->has_union = true;
@@ -170,21 +183,6 @@ void PyiGenerator::PrintImportForDescriptor(
170183
}
171184

172185
void PyiGenerator::PrintImports() const {
173-
// Prints imported dependent _pb2 files.
174-
absl::flat_hash_set<std::string> seen_aliases;
175-
bool has_importlib = false;
176-
for (int i = 0; i < file_->dependency_count(); ++i) {
177-
const FileDescriptor* dep = file_->dependency(i);
178-
if (strip_nonfunctional_codegen_ && IsKnownFeatureProto(dep->name())) {
179-
continue;
180-
}
181-
PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib);
182-
for (int j = 0; j < dep->public_dependency_count(); ++j) {
183-
PrintImportForDescriptor(*dep->public_dependency(j), &seen_aliases,
184-
&has_importlib);
185-
}
186-
}
187-
188186
// Checks what modules should be imported.
189187
ImportModules import_modules;
190188
if (file_->message_type_count() > 0) {
@@ -201,6 +199,24 @@ void PyiGenerator::PrintImports() const {
201199
for (int i = 0; i < file_->message_type_count(); i++) {
202200
CheckImportModules(file_->message_type(i), &import_modules);
203201
}
202+
if (import_modules.has_datetime) {
203+
printer_->Print("import datetime\n\n");
204+
}
205+
206+
// Prints imported dependent _pb2 files.
207+
absl::flat_hash_set<std::string> seen_aliases;
208+
bool has_importlib = false;
209+
for (int i = 0; i < file_->dependency_count(); ++i) {
210+
const FileDescriptor* dep = file_->dependency(i);
211+
if (strip_nonfunctional_codegen_ && IsKnownFeatureProto(dep->name())) {
212+
continue;
213+
}
214+
PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib);
215+
for (int j = 0; j < dep->public_dependency_count(); ++j) {
216+
PrintImportForDescriptor(*dep->public_dependency(j), &seen_aliases,
217+
&has_importlib);
218+
}
219+
}
204220

205221
// Prints modules (e.g. _containers, _messages, typing) that are
206222
// required in the proto file.

0 commit comments

Comments
 (0)