From 22adf62fb636e2456b79bf3cf5725cef2a2ab263 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 03:07:19 +0000 Subject: [PATCH 01/19] refactor(http-client-csharp): replace Roslyn reference map analysis Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/ClientProvider.cs | 77 + .../Providers/ClientUriBuilderDefinition.cs | 1 + .../Providers/CollectionResultDefinition.cs | 16 + .../MrwSerializationTypeDefinition.Xml.cs | 6 +- .../MrwSerializationTypeDefinition.cs | 71 +- ...ultipartFormDataSerializationDefinition.cs | 6 + .../src/Providers/RestClientProvider.cs | 146 +- .../SerializationFormatDefinition.cs | 2 + .../src/Snippets/HttpRequestApiSnippets.cs | 4 +- ...ClientBodyDependencyPostProcessingTests.cs | 660 +++++ .../SystemObjectModelSerializationTests.cs | 34 - .../RestClientProviderTests.cs | 32 - .../test/TestHelpers/MockHelpers.cs | 17 +- .../src/InputTypes/InputModelTypeUsage.cs | 1 - .../test/TypeSpecInputConverterTests.cs | 36 - .../src/CSharpGen.cs | 41 +- .../src/LibraryVisitor.cs | 4 +- .../PostProcessing/GeneratedCodeWorkspace.cs | 12 +- .../src/PostProcessing/PostProcessor.cs | 177 +- .../ProviderReferenceMapAnalyzer.cs | 2222 +++++++++++++++++ .../ProviderReferenceMapResult.cs | 14 + .../src/Primitives/TypeProviderWriter.cs | 2 +- .../src/Providers/NamedTypeSymbolProvider.cs | 177 ++ .../src/Providers/TypeProvider.cs | 89 +- .../src/SourceInput/SourceInputModel.cs | 26 + .../src/TypeFactory.cs | 13 +- .../test/OutputLibraryVisitorTests.cs | 4 +- .../test/PostProcessing/PostProcessorTests.cs | 40 +- .../RootClient.cs | 15 + .../ModelFactoriesCustomizationTests.cs | 26 + .../ModelFactoryProviderTests.cs | 79 - .../DerivedModel.cs | 9 + .../ClientCustomizationTests.cs | 2 +- .../ModelProviders/ModelProviderTests.cs | 4 +- .../Generated/DocumentationModelFactory.cs | 1 - .../Generated/ParametersBasicModelFactory.cs | 1 - .../Generated/ParametersSpreadModelFactory.cs | 2 - .../Generated/PayloadMultiPartModelFactory.cs | 1 - .../Generated/PayloadPageableModelFactory.cs | 6 - .../src/Generated/SpecialWordsModelFactory.cs | 2 - 40 files changed, 3557 insertions(+), 521 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index f8ee8744e32..55fee64b30f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -43,6 +43,7 @@ private record ApiVersionFields(FieldProvider Field, PropertyProvider? Correspon private const string ClientSuffix = "Client"; private readonly FormattableString _publicCtorDescription; private readonly InputClient _inputClient; + protected override bool IsClientProvider => true; internal InputClient InputClient => _inputClient; private readonly InputAuth? _inputAuth; private readonly ParameterProvider _endpointParameter; @@ -426,6 +427,82 @@ private IReadOnlyList GetClientParameters() protected override string BuildName() => _inputClient.IsExactName ? _inputClient.Name : _inputClient.Name.ToIdentifierName(); + protected override IReadOnlyList BuildHelperDependencyTypes() + { + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements != null) + { + return [new CancellationTokenExtensionsDefinition().Type, new ClientPipelineExtensionsDefinition().Type]; + } + } + + return []; + } + + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List(); + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements == null) + { + continue; + } + + if (method.CollectionDefinition != null) + { + dependencies.Add(method.CollectionDefinition.Type); + } + + if (method.ServiceMethod == null) + { + continue; + } + + AddInputTypeDependency(dependencies, method.ServiceMethod.Response.Type); + AddInputTypeDependency(dependencies, method.ServiceMethod.Exception?.Type); + foreach (var parameter in method.ServiceMethod.Parameters) + { + if (IsContentTypeParameter(parameter)) + { + continue; + } + + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var parameter in method.ServiceMethod.Operation.Parameters) + { + if (IsContentTypeParameter(parameter)) + { + continue; + } + + AddInputTypeDependency(dependencies, parameter.Type); + } + + // Operation responses are input metadata. The generated method signature and body + // dependencies above capture the response types that are actually used. + } + + return dependencies; + } + + private static bool IsContentTypeParameter(InputParameter parameter) => + parameter is InputHeaderParameter { IsContentType: true } || + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + + private static void AddInputTypeDependency(List dependencies, InputType? inputType) + { + var type = inputType == null ? null : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputType); + if (type != null) + { + dependencies.Add(type); + } + } + protected override FieldProvider[] BuildFields() { List fields = [EndpointField]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs index 92e0cf23a67..4dc5c283cc7 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs @@ -28,6 +28,7 @@ internal sealed class ClientUriBuilderDefinition : TypeProvider private readonly FieldProvider _uriBuilderField; private readonly FieldProvider _pathAndQueryField; private readonly FieldProvider _pathLengthField; + protected override bool IncludeGeneratedBodyReferences => true; private PropertyProvider? _uriBuilderProperty; private PropertyProvider UriBuilderProperty => _uriBuilderProperty ??= new( diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index ae617957bf5..590eaf2b935 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -217,6 +217,22 @@ private bool HasPagingOperationNameCollision(string operationName) protected override TypeSignatureModifiers BuildDeclarationModifiers() => TypeSignatureModifiers.Internal | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class; + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType }; + if (ItemModelType != null) + { + dependencies.Add(ItemModelType); + } + + foreach (var field in RequestFields) + { + dependencies.Add(field.Type); + } + + return dependencies; + } + protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs index 5d76b8f44a1..2b50fff372e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs @@ -67,7 +67,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Internal | MethodSignatureModifiers.Virtual; - if (_shouldOverrideXmlMethods) + if (_shouldOverrideMethods) { modifiers = MethodSignatureModifiers.Internal | MethodSignatureModifiers.Override; } @@ -81,7 +81,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() { - var categorizedProperties = _shouldOverrideXmlMethods + var categorizedProperties = _shouldOverrideMethods ? CategorizedXmlProperties : AllCategorizedXmlProperties; var statements = new List @@ -90,7 +90,7 @@ private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() MethodBodyStatement.EmptyLine }; - if (_shouldOverrideXmlMethods) + if (_shouldOverrideMethods) { statements.Add(Base.Invoke(XmlModelWriteCoreMethodName, _xmlWriterParameter, _serializationOptionsParameter).Terminate()); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 8202a2af405..97532886c40 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -53,16 +53,10 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly ScopedApi _mrwOptionsParameterSnippet; private readonly ScopedApi _jsonElementParameterSnippet; private readonly ScopedApi _isNotEqualToWireConditionSnippet; - // These interface types depend on _model.Type. Build them lazily so we do not cache a - // CSharpType before delayed base model resolution has updated the model's inheritance. - private CSharpType? _jsonModelTInterfaceValue; - private CSharpType _jsonModelTInterface => _jsonModelTInterfaceValue ??= new CSharpType(typeof(IJsonModel<>), SerializationInterfaceType.Type); - private CSharpType? _jsonModelObjectInterface; - private CSharpType? JsonModelObjectInterface => _isStruct ? _jsonModelObjectInterface ??= (CSharpType)typeof(IJsonModel) : null; - private CSharpType? _persistableModelTInterfaceValue; - private CSharpType _persistableModelTInterface => _persistableModelTInterfaceValue ??= new CSharpType(typeof(IPersistableModel<>), SerializationInterfaceType.Type); - private CSharpType? _persistableModelObjectInterface; - private CSharpType? PersistableModelObjectInterface => _isStruct ? _persistableModelObjectInterface ??= (CSharpType)typeof(IPersistableModel) : null; + private readonly CSharpType _jsonModelTInterface; + private readonly CSharpType? _jsonModelObjectInterface; + private readonly CSharpType _persistableModelTInterface; + private readonly CSharpType? _persistableModelObjectInterface; private readonly ModelProvider _model; private readonly InputModelType _inputModel; private readonly FieldProvider? _rawDataField; @@ -73,20 +67,10 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly bool _supportsXml; private ConstructorProvider? _serializationConstructor; // Flag to determine if the model should override the serialization methods - private bool? _shouldOverrideMethods; - private bool ShouldOverrideMethods => _shouldOverrideMethods ??= _model.BaseModelProvider != null && !_isStruct; - private bool? _shouldSkipSerializationMethodOverrides; - private bool ShouldSkipSerializationMethodOverrides => _shouldSkipSerializationMethodOverrides ??= ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); - private readonly bool _shouldOverrideXmlMethods; + private readonly bool _shouldOverrideMethods; + private readonly bool _shouldSkipDerivedSerializationMethodOverrides; private readonly Lazy _additionalProperties; - // Unknown discriminator models use their base model as the serialization interface type. - // This can also touch model.Type, so defer it until serialization method/interface emission. - private TypeProvider SerializationInterfaceType => _serializationInterfaceType ??= _inputModel.IsUnknownDiscriminatorModel - ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel!)! - : _model; - private TypeProvider? _serializationInterfaceType; - private CSharpType RootType => _rootType ??= GetRootModelType(); private CSharpType? _rootType; @@ -100,10 +84,17 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m _isStruct = _model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Struct); _supportsXml = inputModel.Usage.HasFlag(InputModelTypeUsage.Xml); _supportsJson = inputModel.Usage.HasFlag(InputModelTypeUsage.Json) || !_supportsXml; - _shouldOverrideXmlMethods = _model.BaseModelProvider != null && !_isStruct; + // Initialize the serialization interfaces + var interfaceType = inputModel.IsUnknownDiscriminatorModel ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel.BaseModel!)! : _model; + _jsonModelTInterface = new CSharpType(typeof(IJsonModel<>), interfaceType.Type); + _jsonModelObjectInterface = _isStruct ? (CSharpType)typeof(IJsonModel) : null; + _persistableModelTInterface = new CSharpType(typeof(IPersistableModel<>), interfaceType.Type); + _persistableModelObjectInterface = _isStruct ? (CSharpType)typeof(IPersistableModel) : null; _rawDataField = _model.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); _additionalBinaryDataProperty = new(GetAdditionalBinaryDataPropertiesProp); _additionalProperties = new(() => [.. _model.Properties.Where(p => p.IsAdditionalProperties)]); + _shouldOverrideMethods = _model.BaseModelProvider != null && !_isStruct; + _shouldSkipDerivedSerializationMethodOverrides = ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); _utf8JsonWriterSnippet = _utf8JsonWriterParameter.As(); _mrwOptionsParameterSnippet = _serializationOptionsParameter.As(); _jsonElementParameterSnippet = _jsonElementDeserializationParam.As(); @@ -126,6 +117,10 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m protected override CSharpType? BuildBaseType() => _model.BaseType; + protected override IReadOnlyList BuildHelperDependencyTypes() => _rawDataField != null || _additionalProperties.Value.Length > 0 + ? [ScmCodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() { if (_model.CanonicalView.Properties.Any(p => ScmModelProvider.IsFileBinaryContentType(p.Type))) @@ -438,19 +433,17 @@ protected override CSharpType[] BuildImplements() if (_supportsJson) { interfaces.Add(_jsonModelTInterface); - var jsonModelObjectInterface = JsonModelObjectInterface; - if (jsonModelObjectInterface != null) + if (_jsonModelObjectInterface != null) { - interfaces.Add(jsonModelObjectInterface); + interfaces.Add(_jsonModelObjectInterface); } } else if (_supportsXml) { interfaces.Add(_persistableModelTInterface); - var persistableModelObjectInterface = PersistableModelObjectInterface; - if (persistableModelObjectInterface != null) + if (_persistableModelObjectInterface != null) { - interfaces.Add(persistableModelObjectInterface); + interfaces.Add(_persistableModelObjectInterface); } } @@ -480,7 +473,7 @@ internal MethodProvider BuildJsonModelWriteMethodObjectDeclaration() var castToT = This.CastTo(_jsonModelTInterface); return new MethodProvider ( - new MethodSignature(nameof(IJsonModel.Write), null, MethodSignatureModifiers.None, null, null, [_utf8JsonWriterParameter, _serializationOptionsParameter], ExplicitInterface: JsonModelObjectInterface), + new MethodSignature(nameof(IJsonModel.Write), null, MethodSignatureModifiers.None, null, null, [_utf8JsonWriterParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelObjectInterface), castToT.Invoke(nameof(IJsonModel.Write), [_utf8JsonWriterParameter, _serializationOptionsParameter]), this ); @@ -495,7 +488,7 @@ internal MethodProvider BuildJsonModelCreateMethodObjectDeclaration() var castToT = This.CastTo(_jsonModelTInterface); return new MethodProvider ( - new MethodSignature(nameof(IJsonModel.Create), null, MethodSignatureModifiers.None, typeof(object), null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: JsonModelObjectInterface), + new MethodSignature(nameof(IJsonModel.Create), null, MethodSignatureModifiers.None, typeof(object), null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelObjectInterface), castToT.Invoke(nameof(IJsonModel.Create), [_utf8JsonReaderParameter.AsArgument(), _serializationOptionsParameter]), this ); @@ -511,7 +504,7 @@ internal MethodProvider BuildPersistableModelWriteMethodObjectDeclaration() var returnType = typeof(BinaryData); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.Write), null, MethodSignatureModifiers.None, returnType, null, [_serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.Write), null, MethodSignatureModifiers.None, returnType, null, [_serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.Write), [_serializationOptionsParameter]), this ); @@ -527,7 +520,7 @@ internal MethodProvider BuildPersistableModelCreateMethodObjectDeclaration() var returnType = typeof(object); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.Create), null, MethodSignatureModifiers.None, returnType, null, [_dataParameter, _serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.Create), null, MethodSignatureModifiers.None, returnType, null, [_dataParameter, _serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.Create), [_dataParameter, _serializationOptionsParameter]), this ); @@ -541,7 +534,7 @@ internal MethodProvider BuildJsonModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods) + if (_shouldOverrideMethods) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -563,7 +556,7 @@ internal MethodProvider BuildPersistableModelWriteCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -587,7 +580,7 @@ internal MethodProvider BuildPersistableModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -635,7 +628,7 @@ internal MethodProvider BuildJsonModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -807,7 +800,7 @@ internal MethodProvider BuildPersistableModelGetFormatFromOptionsObjectDeclarati // string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => ((IPersistableModel)this).GetFormatFromOptions(options); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.GetFormatFromOptions), null, MethodSignatureModifiers.None, typeof(string), null, [_serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.GetFormatFromOptions), null, MethodSignatureModifiers.None, typeof(string), null, [_serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.GetFormatFromOptions), [_serializationOptionsParameter]), this ); @@ -1066,7 +1059,7 @@ private MethodBodyStatement[] BuildPersistableModelCreateCoreMethodBody() private MethodBodyStatement CallBaseJsonModelWriteCore(bool isDynamicModelWithNonDynamicBase) { // base.() - bool callBaseWriteMethod = ShouldOverrideMethods + bool callBaseWriteMethod = _shouldOverrideMethods && (_jsonPatchProperty is null || !isDynamicModelWithNonDynamicBase); return callBaseWriteMethod ? Base.Invoke(JsonModelWriteCoreMethodName, [_utf8JsonWriterParameter, _serializationOptionsParameter]).Terminate() diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs index 88cb97b16e7..696d654dd9c 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs @@ -51,6 +51,12 @@ protected override string BuildRelativeFilePath() return Path.Combine("src", "Generated", "Models", $"{Name}.Serialization.Multipart.cs"); } + protected override IReadOnlyList BuildHelperDependencyTypes() => _model.Properties.Any( + prop => prop.WireInfo != null && !prop.WireInfo.IsRequired && + (prop.Type is { IsCollection: true, IsReadOnlyMemory: false } || prop.Type.IsDictionary)) + ? [ScmCodeModelGenerator.Instance.TypeFactory.OptionalType] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() => [new SuppressionStatement(null, Literal(ScmModelProvider.FileBinaryContentDiagnosticId), ScmModelProvider.ScmEvaluationTypeSuppressionJustification)]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index f00d78571d0..ba6967e7d22 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -78,6 +78,44 @@ protected override FieldProvider[] BuildFields() return [.. pipelineMessage20xClassifiersFields]; } + protected override IReadOnlyList BuildHelperDependencyTypes() + { + var dependencies = new List { new ClientUriBuilderDefinition().Type }; + foreach (var serviceMethod in _inputClient.Methods) + { + foreach (var parameter in serviceMethod.Operation.Parameters) + { + if (IsGeneratedContentTypeMethodParameter(parameter) || + parameter is not InputHeaderParameter and not InputQueryParameter) + { + continue; + } + + var type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(parameter.Type); + if (type?.IsDictionary == true) + { + AddDependency(dependencies, ScmCodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType); + } + else if (type?.IsCollection == true) + { + AddDependency(dependencies, ScmCodeModelGenerator.Instance.TypeFactory.ListInitializationType); + } + } + } + + return dependencies; + } + + private static void AddDependency(List dependencies, CSharpType dependency) + { + if (!dependencies.Any(existing => + existing.Name == dependency.Name && + existing.Namespace == dependency.Namespace)) + { + dependencies.Add(dependency); + } + } + protected override ScmMethodProvider[] BuildMethods() { List methods = new List(); @@ -549,18 +587,6 @@ private static MethodBodyStatement BuildAppendQueryStatement( { if (paramType?.IsCollection != true) { - // A model-typed query parameter marked with `explode` must be expanded into one query - // entry per property (RFC 6570 form explode, e.g. `?field=status&value=active`) rather - // than serialized via the object's ToString (which previously produced the type name). - if (inputQueryParameter.Explode && inputQueryParameter.Type is InputModelType inputModel) - { - var explodeStatement = BuildExplodeModelQueryStatement(uri, inputModel, valueExpression); - if (explodeStatement != null) - { - return explodeStatement; - } - } - var toStringExpression = GetQueryParameterStringExpression(paramType, valueExpression, serializationFormat); return uri.AppendQuery(Literal(inputQueryParameter.SerializedName), toStringExpression, true).Terminate(); } @@ -617,70 +643,6 @@ private static MethodBodyStatement BuildAppendQueryStatement( } } - /// - /// Builds the statements for a model-typed query parameter that uses form-style `explode`. - /// Each (simple) property of the model is emitted as its own query entry using the property's - /// wire name (RFC 6570 form explode, e.g. ?field=status&value=active). - /// Returns null when the model contains a property that is not a simple scalar/enum - /// (e.g. a nested object or a collection), in which case the caller falls back to the default - /// handling. Nested/complex expansion is tracked separately (see issue #11123). - /// - private static MethodBodyStatement? BuildExplodeModelQueryStatement( - ScopedApi uri, - InputModelType inputModel, - ValueExpression valueExpression) - { - var modelProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel); - if (modelProvider is null) - { - return null; - } - - var properties = modelProvider.CanonicalView.Properties; - if (properties.Count == 0) - { - return null; - } - - // Only expand when every property is a simple scalar or enum. Nested objects and - // collections are not defined by RFC 6570 form explode and require a separate design - // decision, so we fall back to the default handling for those. - foreach (var property in properties) - { - if (property.WireInfo is null || - property.Type.IsCollection || - (!property.Type.IsFrameworkType && !property.Type.IsEnum)) - { - return null; - } - } - - var statements = new List(); - foreach (var property in properties) - { - var propertyAccess = valueExpression.Property(property.Name); - var propertyType = property.Type; - - ValueExpression convertedValue = propertyType.IsEnum - ? propertyType.ToSerial(propertyAccess).ConvertToString() - : GetQueryParameterStringExpression(propertyType, propertyAccess, property.SerializationFormat); - - MethodBodyStatement appendStatement = - uri.AppendQuery(Literal(property.WireInfo!.SerializedName), convertedValue, true).Terminate(); - - if (!property.WireInfo.IsRequired || - propertyType.IsNullable || - (propertyType is { IsValueType: false, IsFrameworkType: true } && propertyType.FrameworkType != typeof(string))) - { - appendStatement = BuildQueryOrHeaderOrPathParameterNullCheck(propertyType, propertyAccess, appendStatement); - } - - statements.Add(appendStatement); - } - - return statements; - } - private static IfStatement BuildQueryOrHeaderOrPathParameterNullCheck( CSharpType? parameterType, ValueExpression valueExpression, @@ -919,7 +881,9 @@ private static void AppendLiteralSegment(ScopedApi uri, string literal, List paramMap, InputOperation operation, InputParameter inputParam, out CSharpType? type, out SerializationFormat? serializationFormat, out ValueExpression? valueExpression) { - type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); + type = IsGeneratedContentTypeMethodParameter(inputParam) + ? null + : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); serializationFormat = null; if (inputParam.IsApiVersion && ClientProvider.IsMultiServiceClient) @@ -1208,7 +1172,10 @@ internal static List GetMethodParameters( // when one was already published. UpdateParameterNameWithBackCompat(inputParam, inputParam.Name, client.BackCompatProvider, serviceMethod); - ParameterProvider? parameter = ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); + ParameterProvider? parameter = IsGeneratedContentTypeMethodParameter(inputParam) && + methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest + ? CreateContentTypeParameter(inputParam) + : ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); if (parameter is null) { continue; @@ -1249,7 +1216,7 @@ internal static List GetMethodParameters( break; case ParameterLocation.Query: case ParameterLocation.Header: - if (inputParam is InputHeaderParameter { IsContentType: true } + if (IsGeneratedContentTypeMethodParameter(inputParam) && !HasContentTypeBeforeBodyInLastContract(serviceMethod.Name, client.BackCompatProvider)) { sortedParams.Add(contentType++, parameter); @@ -1292,12 +1259,25 @@ internal static List GetMethodParameters( return [.. sortedParams.Values]; } + private static ParameterProvider CreateContentTypeParameter(InputParameter inputParam) + { + var type = new CSharpType(typeof(string), isNullable: !inputParam.IsRequired); + return new ParameterProvider( + inputParam.Name, + DocHelpers.GetFormattableDescription(inputParam.Summary, inputParam.Doc) ?? FormattableStringHelpers.Empty, + type, + defaultValue: inputParam.IsRequired ? null : Default, + location: ParameterLocation.Header, + wireInfo: new WireInformation(SerializationFormat.Default, inputParam.SerializedName), + validation: inputParam.IsRequired ? ParameterValidationType.AssertNotNullOrEmpty : ParameterValidationType.None, + inputParameter: inputParam); + } + private static bool HasLiteralContentTypeHeader(InputOperation operation) { foreach (var p in operation.Parameters) { - if (p is InputHeaderParameter { IsContentType: true } header - && header.Type is InputLiteralType) + if (p is InputHeaderParameter { IsContentType: true } && p.Type is InputLiteralType) { return true; } @@ -1305,6 +1285,10 @@ private static bool HasLiteralContentTypeHeader(InputOperation operation) return false; } + private static bool IsGeneratedContentTypeMethodParameter(InputParameter parameter) => + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + /// /// Checks if the last contract view contains a method matching the given name where /// a "contentType" parameter appears before the body ("content") parameter. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs index af294640060..e902afd7c79 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -45,6 +46,7 @@ protected override TypeSignatureModifiers BuildDeclarationModifiers() protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", "Internal", $"{Name}.cs"); protected override string BuildName() => "SerializationFormat"; + protected override FormattableString BuildDescription() => $"The serialization format."; protected override TypeProvider[] BuildSerializationProviders() => []; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs index 588a2094b12..fdd8f72fba0 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs @@ -26,7 +26,9 @@ public static MethodBodyStatement SetContent(this ScopedApi pip public static MethodBodyStatement SetHeaderDelimited(this HttpRequestApi pipelineRequest, string name, ValueExpression value, ValueExpression delimiter, ValueExpression? format = null) { ValueExpression[] parameters = format != null ? [Literal(name), value, delimiter, format] : [Literal(name), value, delimiter]; - return pipelineRequest.Property(nameof(PipelineRequest.Headers)).Invoke("SetDelimited", parameters).Terminate(); + return pipelineRequest.Property(nameof(PipelineRequest.Headers)) + .Invoke("SetDelimited", parameters, typeArguments: null, callAsAsync: false, extensionType: new PipelineRequestHeadersExtensionsDefinition().Type) + .Terminate(); } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs new file mode 100644 index 00000000000..b27943d96ce --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -0,0 +1,660 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.TypeSpec.Generator.Input; +using Microsoft.TypeSpec.Generator.Tests.Common; +using NUnit.Framework; + +namespace Microsoft.TypeSpec.Generator.ClientModel.Tests.PostProcessing +{ + public class ClientBodyDependencyPostProcessingTests + { + [Test] + public async Task OperationBodyParameterModelDoesNotBecomePublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([requestModel], [client], ["RequestBody"]); + } + + [Test] + public async Task OperationResponseBodyModelRemainsPublicAsRootOutputModel() + { + var responseModel = InputFactory.Model("ResponseBody"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel], [client], ["ResponseBody"]); + } + + [Test] + public async Task OperationResponseBodyModelIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyModel = InputFactory.Model("MetadataOnlyResponse"); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyModel, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [metadataOnlyModel], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.Serialization.cs") + ]); + } + + [Test] + public async Task InternalAdditionalRootModelIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyModel = InputFactory.Model("MetadataOnlyResponse", access: "internal"); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyModel, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [metadataOnlyModel], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.Serialization.cs") + ], + configureGenerator: () => + { + var provider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "MetadataOnlyResponse"); + CodeModelGenerator.Instance.AddTypeToKeep(provider); + }); + } + + [Test] + public async Task AdditionalRootEnumIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyEnum = InputFactory.StringEnum( + "MetadataOnlyResponseKind", + [("Accepted", "accepted")]); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyEnum, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [metadataOnlyEnum], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponseKind.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponseKind.Serialization.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumIsRemovedWhenNotOtherwiseReferenced() + { + var contentTypeEnum = InputFactory.StringEnum( + "UpdateSnapshotRequestContentType", + [ + ("ApplicationMergePatchJson", "application/merge-patch+json"), + ("ApplicationJson", "application/json") + ]); + var contentTypeParameter = InputFactory.MethodParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + location: InputRequestLocation.Header, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "UpdateSnapshot", + parameters: [contentTypeParameter], + httpMethod: "PATCH", + generateConvenienceMethod: false); + var method = InputFactory.BasicServiceMethod("UpdateSnapshot", operation); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "UpdateSnapshotRequestContentType.cs"), + Path.Combine("src", "Generated", "Models", "UpdateSnapshotRequestContentType.Serialization.cs") + ], + configureGenerator: () => + CodeModelGenerator.Instance.TypeFactory.CreateCSharpType(InputFactory.Union([contentTypeEnum], "contentType"))); + } + + [Test] + public async Task ContentTypeHeaderEnumReferencedByCustomSuppressionIsKept() + { + var contentTypeEnum = InputFactory.StringEnum( + "PutKeyValueRequestContentType", + [("ApplicationJson", "application/json")], + isExtensible: true); + var contentTypeParameter = InputFactory.HeaderParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + isContentType: true, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "SetConfigurationSettingInternal", + parameters: [contentTypeParameter], + httpMethod: "PUT"); + var method = InputFactory.BasicServiceMethod("SetConfigurationSettingInternal", operation); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [ + (Path.Combine("src", "PutKeyValueRequestContentType.cs"), """ + namespace Sample.Models; + + internal readonly partial struct PutKeyValueRequestContentType + { + public static PutKeyValueRequestContentType ApplicationJson { get; } = new PutKeyValueRequestContentType("application/json"); + } + """) + ], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "PutKeyValueRequestContentType.cs") + ]); + } + + [Test] + public async Task NestedBodyModelGraphDoesNotBecomePublic() + { + var nestedModel = InputFactory.Model("NestedToolParameter"); + var toolModel = InputFactory.Model( + "ToolConfig", + properties: [InputFactory.Property("Parameter", nestedModel)]); + var parameter = InputFactory.BodyParameter("tool", toolModel, isRequired: true); + var operation = InputFactory.Operation("Configure", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Configure", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([toolModel, nestedModel], [client], ["ToolConfig", "NestedToolParameter"]); + } + + [Test] + public async Task NonDiscriminatorDerivedBodyModelDoesNotBecomePublicFromPublicBase() + { + var baseTool = InputFactory.Model("BaseTool"); + var concreteTool = InputFactory.Model( + "ConcreteTool", + properties: [InputFactory.Property("Name", InputPrimitiveType.String)], + baseModel: baseTool); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseTool)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseTool, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertMixedModels( + [baseTool, concreteTool], + [client], + publicModelNames: ["BaseTool"], + internalModelNames: ["ConcreteTool"]); + } + + [Test] + public async Task PublicModelSignatureDependencyIsPromotedToPublic() + { + var internalDependency = InputFactory.Model("InternalDependency", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependency", internalDependency)]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel, internalDependency], [client], ["ResponseBody", "InternalDependency"]); + } + + [Test] + public async Task AzureClientPublicMethodSignatureReferencesStayPublic() + { + var signatureModel = InputFactory.Model("SignatureModel", @namespace: "Azure.Sample.Models"); + var methodParameter = InputFactory.MethodParameter("signature", signatureModel, isRequired: true); + var operation = InputFactory.Operation( + "Create", + parameters: [InputFactory.BodyParameter("signature", signatureModel, isRequired: true)], + httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation, parameters: [methodParameter]); + var client = InputFactory.Client("SampleClient", clientNamespace: "Azure.Sample", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [signatureModel], + clients: [client], + customFiles: [], + expectedFiles: [], + publicModelNames: ["SignatureModel"], + packageName: "Azure.Sample"); + } + + [Test] + public async Task BasePreservedDerivedModelTraversesTransitiveDependencies() + { + var transitiveDependency = InputFactory.Model("TransitiveDependency"); + var dependency = InputFactory.Model( + "DerivedDependency", + properties: [InputFactory.Property("Transitive", transitiveDependency)]); + var baseModel = InputFactory.Model("BaseResult"); + var derivedModel = InputFactory.Model( + "DerivedResult", + properties: [InputFactory.Property("Dependency", dependency)], + baseModel: baseModel); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [baseModel, derivedModel, dependency, transitiveDependency], + clients: [client], + customFiles: [], + expectedFiles: [], + publicModelNames: ["BaseResult"], + internalModelNames: ["DerivedResult", "DerivedDependency", "TransitiveDependency"]); + } + + [Test] + public async Task PublicCustomCodeArraySignatureReferencesStayPublic() + { + var generatedModel = InputFactory.Model("GeneratedModel"); + + await GenerateAndAssertFiles( + enums: [], + models: [generatedModel], + clients: [], + customFiles: [ + (Path.Combine("src", "PublicCustomApi.cs"), """ + using Sample.Models; + + namespace Sample; + + public partial class PublicCustomApi + { + public GeneratedModel[] Items { get; } = System.Array.Empty(); + } + """) + ], + expectedFiles: [], + publicModelNames: ["GeneratedModel"]); + } + + [Test] + public async Task GeneratedRequestHeaderSetDelimitedReferenceKeepsExtensions() + { + var header = InputFactory.HeaderParameter("x-ms-custom", InputFactory.Array(InputPrimitiveType.String), isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [header]); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [Path.Combine("src", "Generated", "Internal", "PipelineRequestHeadersExtensions.cs")]); + } + + [Test] + public async Task BinaryDataBodyParameterDoesNotKeepBinaryContentHelpers() + { + var parameter = InputFactory.BodyParameter( + "content", + InputPrimitiveType.Base64, + isRequired: true, + contentTypes: ["application/octet-stream"], + defaultContentType: "application/octet-stream"); + var operation = InputFactory.Operation("Upload", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod( + "Upload", + operation, + parameters: [InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true)]); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Internal", "BinaryContentHelper.cs"), + Path.Combine("src", "Generated", "Internal", "Utf8JsonBinaryContent.cs") + ]); + } + + [Test] + public async Task CollectionBodyParameterKeepsBinaryContentHelpers() + { + var parameter = InputFactory.BodyParameter("items", InputFactory.Array(InputPrimitiveType.String), isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod( + "Create", + operation, + parameters: [InputFactory.MethodParameter("items", InputFactory.Array(InputPrimitiveType.String), isRequired: true)]); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [ + Path.Combine("src", "Generated", "Internal", "BinaryContentHelper.cs"), + Path.Combine("src", "Generated", "Internal", "Utf8JsonBinaryContent.cs") + ]); + } + + [Test] + public async Task CustomOnlyRequestHeaderSetDelimitedReferenceKeepsExtensions() + { + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [], + customFiles: [ + (Path.Combine("src", "CustomHeaders.cs"), """ + using System.ClientModel.Primitives; + + namespace Sample; + + public static class CustomHeaders + { + public static void Add(PipelineRequestHeaders headers, string[] values) + => headers.SetDelimited("x-ms-custom", values, ","); + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Internal", "PipelineRequestHeadersExtensions.cs")]); + } + + [Test] + public async Task CustomizedEnumSerializationProviderIsKeptWhenModelSerializationUsesEnum() + { + var statusEnum = InputFactory.StringEnum( + "Status", + [("Succeeded", "succeeded"), ("Failed", "failed")], + clientNamespace: "Sample"); + var resultModel = InputFactory.Model( + "OperationResult", + properties: [InputFactory.Property("Status", statusEnum, isRequired: true)], + @namespace: "Sample"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: resultModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(resultModel, [])); + var client = InputFactory.Client("TestClient", methods: [method], clientNamespace: "Sample"); + + await GenerateAndAssertFiles( + enums: [statusEnum], + models: [resultModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Custom", "Status.cs"), """ + namespace Sample; + + [CodeGenType("Status")] + public enum Status + { + Succeeded, + Failed + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Models", "Status.Serialization.cs")]); + } + + [Test] + public async Task CustomModelFactoryPartialDoesNotKeepBodyOnlyModelPublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [requestModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "SampleModelFactory.cs"), """ + namespace Sample; + + [Microsoft.TypeSpec.Generator.Customizations.CodeGenType("SampleModelFactory")] + public static partial class SampleModelFactory + { + } + """) + ], + expectedFiles: [], + internalModelNames: ["RequestBody"]); + } + + [Test] + public async Task InternalCustomClientPartialOverridesLastContractPublicClient() + { + var responseModel = InputFactory.Model("CompactResource"); + var operation = InputFactory.Operation("Compact", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod("Compact", operation, response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("Responses", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [responseModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Generated", "Responses.cs"), """ + namespace Sample; + + public partial class Responses + { + } + """), + (Path.Combine("src", "Custom", "Internal", "Responses.cs"), """ + namespace Sample; + + internal partial class Responses + { + } + """) + ], + expectedFiles: [], + internalModelNames: ["CompactResource"], + internalClientNames: ["Responses"]); + } + + private static async Task GenerateAndAssertInternalModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: false); + + private static async Task GenerateAndAssertPublicModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: true); + + private static async Task GenerateAndAssertMixedModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + => await GenerateAndAssertModels(models, clients, publicModelNames, internalModelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames, + bool shouldBePublic) + => await GenerateAndAssertModels( + models, + clients, + shouldBePublic ? modelNames : [], + shouldBePublic ? [] : modelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + { + await GenerateAndAssertFiles( + enums: [], + models: models, + clients: clients, + customFiles: [], + publicModelNames: publicModelNames, + internalModelNames: internalModelNames, + expectedFiles: []); + } + + private static async Task GenerateAndAssertFiles( + InputEnumType[] enums, + InputModelType[] models, + InputClient[] clients, + (string Path, string Content)[] customFiles, + string[] expectedFiles, + string[] unexpectedFiles = null!, + string[] publicModelNames = null!, + string[] internalModelNames = null!, + string[] internalClientNames = null!, + string packageName = "Sample", + Action? configureGenerator = null) + { + publicModelNames ??= []; + internalModelNames ??= []; + internalClientNames ??= []; + unexpectedFiles ??= []; + + var outputPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Directory.CreateDirectory(outputPath); + try + { + foreach (var customFile in customFiles) + { + var customPath = Path.Combine(outputPath, customFile.Path); + Directory.CreateDirectory(Path.GetDirectoryName(customPath)!); + File.WriteAllText(customPath, customFile.Content); + } + + await MockHelpers.LoadMockGeneratorAsync( + inputEnums: () => enums, + inputModels: () => models, + clients: () => clients, + configuration: $$"""{ "package-name": "{{packageName}}", "disable-xml-docs": true }""", + outputPath: outputPath); + configureGenerator?.Invoke(); + + await new CSharpGen().ExecuteAsync(); + + foreach (var modelName in publicModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"public partial class {modelName}", text, $"{modelName} should be public."); + } + + foreach (var modelName in internalModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"internal partial class {modelName}", text, $"{modelName} should be internal."); + StringAssert.DoesNotContain($"public partial class {modelName}", text, $"{modelName} should not be public."); + } + + foreach (var clientName in internalClientNames) + { + var clientPath = Path.Combine(outputPath, "src", "Generated", $"{clientName}.cs"); + Assert.IsTrue(File.Exists(clientPath), $"Expected generated client file '{clientPath}'."); + var text = File.ReadAllText(clientPath); + StringAssert.Contains($"internal partial class {clientName}", text, $"{clientName} should be internal."); + StringAssert.DoesNotContain($"public partial class {clientName}", text, $"{clientName} should not be public."); + } + + var modelFactoryPath = Path.Combine(outputPath, "src", "Generated", "SampleModelFactory.cs"); + if (File.Exists(modelFactoryPath)) + { + var modelFactoryText = File.ReadAllText(modelFactoryPath); + foreach (var modelName in publicModelNames) + { + StringAssert.Contains($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should be generated."); + } + + foreach (var modelName in internalModelNames) + { + StringAssert.DoesNotContain($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should not be generated."); + } + } + + foreach (var expectedFile in expectedFiles) + { + var filePath = Path.Combine(outputPath, expectedFile); + Assert.IsTrue(File.Exists(filePath), $"Expected generated file '{filePath}'."); + } + + foreach (var unexpectedFile in unexpectedFiles) + { + var filePath = Path.Combine(outputPath, unexpectedFile); + Assert.IsFalse(File.Exists(filePath), $"Did not expect generated file '{filePath}'."); + } + } + finally + { + if (Directory.Exists(outputPath)) + { + Directory.Delete(outputPath, recursive: true); + } + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs index fcc90582416..503ed1ad68f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs @@ -121,31 +121,6 @@ public void JsonModelWriteCore_IsOverride_WhenBaseIsRegularModel() "JsonModelWriteCore should be 'override' with regular base too"); } - [Test] - public void JsonModelWriteCore_IsOverride_WhenBaseProviderIsResolvedAfterSerialization() - { - var baseInputModel = InputFactory.Model("Resource"); - var derivedInputModel = InputFactory.Model("TrackedResource", properties: [InputFactory.Property("Location", InputPrimitiveType.String)]); - MockHelpers.LoadMockGenerator(inputModels: () => [baseInputModel, derivedInputModel]); - - var derived = new DelayedBaseModelProvider(derivedInputModel); - var serialization = new MrwSerializationTypeDefinition(derivedInputModel, derived); - - // The serialization provider can be constructed before later visitors/customization - // resolution make the base model provider available. - derived.BaseModel = new SystemObjectModelProvider(new CSharpType(typeof(object)), baseInputModel); - - var method = serialization.BuildJsonModelWriteCoreMethod(); - - Assert.AreEqual(derived.BaseModel.Type, derived.Type.BaseType, - "The generated model type should inherit the base resolved after serialization construction."); - Assert.AreEqual(derived.BaseModel.Type, serialization.Type.BaseType, - "The serialization type should inherit the same resolved base."); - Assert.IsTrue(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Override), - "JsonModelWriteCore should evaluate BaseModelProvider when the method is built, not when serialization is constructed"); - Assert.IsFalse(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Virtual)); - } - // ------------------------------------------------------------------- // PersistableModelWriteCore: 'virtual' with system base, 'override' with regular // (the framework base already implements this; derived model re-introduces it) @@ -353,14 +328,5 @@ FakeMrwBase IPersistableModel.Create(BinaryData data, ModelReaderWr string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; } - - private class DelayedBaseModelProvider(InputModelType inputModel) : ModelProvider(inputModel) - { - public ModelProvider? BaseModel { get; set; } - - protected override ModelProvider? BuildBaseModelProvider() => BaseModel; - - protected override CSharpType? BuildBaseType() => BaseModel?.Type; - } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs index e8bcb2fa0fb..7e637e07363 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs @@ -738,38 +738,6 @@ public void ValidateGetResponseClassifiersThrowsWhenNoSuccess() Assert.Fail("Expected Exception to be thrown."); } - [Test] - public void TestBuildCreateRequestMethodWithExplodedModelQueryParameter() - { - var filterModel = InputFactory.Model( - "filterOptions", - properties: - [ - InputFactory.Property("field", InputPrimitiveType.String, isRequired: true), - InputFactory.Property("value", InputPrimitiveType.String, isRequired: true), - ]); - var operation = InputFactory.Operation( - "sampleOp", - parameters: [InputFactory.QueryParameter("filter", filterModel, isRequired: true, explode: true)]); - var client = InputFactory.Client( - "TestClient", - methods: [InputFactory.BasicServiceMethod("Test", operation)]); - var clientProvider = new ClientProvider(client); - var restClientProvider = new MockClientProvider(client, clientProvider); - - var method = restClientProvider.Methods.FirstOrDefault(m => m.Signature.Name == "CreateSampleOpRequest"); - Assert.IsNotNull(method); - var body = method!.BodyStatements!.ToDisplayString(); - - // A model-typed query parameter with `explode` is expanded into one query entry per - // property (RFC 6570 form explode) using each property's wire name, instead of serializing - // the whole object via ConvertToString (which produced the type name). - Assert.IsTrue(body.Contains("uri.AppendQuery(\"field\", filter.Field, true);"), body); - Assert.IsTrue(body.Contains("uri.AppendQuery(\"value\", filter.Value, true);"), body); - Assert.IsFalse(body.Contains("AppendQuery(\"filter\""), body); - Assert.IsFalse(body.Contains("ConvertToString(filter)"), body); - } - [Test] public void TestBuildCreateRequestMethodWithQueryParameters() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs index 9148f659e43..6c6e743ad89 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs @@ -34,7 +34,8 @@ public static async Task> LoadMockGeneratorAsync( Func>? apiVersions = null, string? configuration = null, Func? createCSharpTypeCore = null, - Func? createCSharpTypeCoreFallback = null) + Func? createCSharpTypeCoreFallback = null, + string? outputPath = null) { var mockGenerator = LoadMockGenerator( inputLiterals: inputLiterals, @@ -44,13 +45,13 @@ public static async Task> LoadMockGeneratorAsync( apiVersions: apiVersions, configuration: configuration, createCSharpTypeCore: createCSharpTypeCore, - createCSharpTypeCoreFallback: createCSharpTypeCoreFallback); + createCSharpTypeCoreFallback: createCSharpTypeCoreFallback, + outputPath: outputPath); var compilationResult = compilation == null ? null : await compilation(); var lastContractCompilationResult = lastContractCompilation == null ? null : await lastContractCompilation(); - var sourceInputModel = new Mock(() => new SourceInputModel(compilationResult, lastContractCompilationResult)) { CallBase = true }; - mockGenerator.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); + mockGenerator.SetupProperty(p => p.SourceInputModel, new SourceInputModel(compilationResult, lastContractCompilationResult)); return mockGenerator; } @@ -76,7 +77,8 @@ public static Mock LoadMockGenerator( Func? createOutputLibrary = null, bool includeXmlDocs = false, Func? createCSharpTypeCoreFallback = null, - Func? createModelCore = null) + Func? createModelCore = null, + string? outputPath = null) { IReadOnlyList inputNsApiVersions = apiVersions?.Invoke() ?? []; IReadOnlyList inputNsLiterals = inputLiterals?.Invoke() ?? []; @@ -150,7 +152,7 @@ public static Mock LoadMockGenerator( { configuration = "{\"disable-xml-docs\": false, \"package-name\": \"Sample.Namespace\"}"; } - object?[] parameters = [_configFilePath, configuration]; + object?[] parameters = [outputPath ?? _configFilePath, configuration]; var config = loadMethod?.Invoke(null, parameters); var mockGeneratorContext = new Mock(config!); var mockGeneratorInstance = new Mock(mockGeneratorContext.Object) { CallBase = true }; @@ -186,8 +188,7 @@ public static Mock LoadMockGenerator( mockGeneratorInstance.Setup(p => p.OutputLibrary).Returns(createOutputLibrary); } - var sourceInputModel = new Mock(() => new SourceInputModel(null, null)) { CallBase = true }; - mockGeneratorInstance.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); + mockGeneratorInstance.SetupProperty(p => p.SourceInputModel, new SourceInputModel(null, null)); codeModelInstance!.SetValue(null, mockGeneratorInstance.Object); clientModelInstance!.SetValue(null, mockGeneratorInstance.Object); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs index 52d9c8c0bc8..f96ecad090e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs @@ -22,6 +22,5 @@ public enum InputModelTypeUsage LroInitial = 2048, LroPolling = 4096, LroFinalEnvelope = 8192, - External = 16384, } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs index 7e5ec1c3c84..9328dc43571 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs @@ -564,42 +564,6 @@ public void DeserializeModelWithExternalMetadata() Assert.AreEqual("8.0.0", model.External.MinVersion); } - [Test] - public void DeserializeModelWithExternalUsagePreservesInputAndOutput() - { - // TCGC emits the External usage flag (UsageFlags.External) for models that are also - // referenced by external types. The C# InputModelTypeUsage enum must recognize it so - // that Enum.TryParse does not fail on the unknown token and collapse the whole usage to - // None, which would strip Input/Output and make every property get-only. - var json = @"{ - ""$id"": ""1"", - ""kind"": ""model"", - ""name"": ""TestModel"", - ""namespace"": ""Test.Models"", - ""crossLanguageDefinitionId"": ""Test.Models.TestModel"", - ""usage"": ""Input,Output,External"", - ""properties"": [] - }"; - - var referenceHandler = new TypeSpecReferenceHandler(); - var options = new JsonSerializerOptions - { - AllowTrailingCommas = true, - Converters = - { - new InputTypeConverter(referenceHandler), - new InputModelTypeConverter(referenceHandler), - new InputExternalTypeMetadataConverter() - } - }; - - var model = JsonSerializer.Deserialize(json, options); - Assert.IsNotNull(model); - Assert.IsTrue(model!.Usage.HasFlag(InputModelTypeUsage.Input), "Model should retain Input usage flag"); - Assert.IsTrue(model.Usage.HasFlag(InputModelTypeUsage.Output), "Model should retain Output usage flag"); - Assert.IsTrue(model.Usage.HasFlag(InputModelTypeUsage.External), "Model should have External usage flag"); - } - [Test] public void DeserializeArrayWithExternalMetadata() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index c013817a72e..35914c1f8d4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -27,12 +27,13 @@ public async Task ExecuteAsync() { CodeModelGenerator.Instance.Emitter.Info("Starting code generation"); CodeModelGenerator.Instance.Stopwatch.Start(); + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); var outputPath = CodeModelGenerator.Instance.Configuration.OutputDirectory; var generatedSourceOutputPath = CodeModelGenerator.Instance.Configuration.ProjectGeneratedDirectory; - // Resolve PackageReference items from the .csproj so custom code referencing - // external NuGet types (e.g., Azure.Storage.Common) compiles correctly. + // Resolve PackageReference items from the .csproj so custom code referencing external + // NuGet types compiles correctly. await GeneratedCodeWorkspace.AddPackageReferencesFromProject(); // Pre-walk the input library and resolve any external types that point at NuGet packages. @@ -90,12 +91,33 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), { // Ensure back-compatibility processing is done after all visitors have run outputType.ProcessTypeForBackCompatibility(); + } + + generatedCodeWorkspace.ApplyPreWriteAccessibility(output.TypeProviders); + generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders); + + foreach (var outputType in output.TypeProviders) + { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(outputType)) + { + continue; + } + + if (outputType is ModelFactoryProvider && outputType.Methods.Count == 0) + { + continue; + } var writer = CodeModelGenerator.Instance.GetWriter(outputType); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); foreach (var serialization in outputType.SerializationProviders) { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(serialization)) + { + continue; + } + writer = CodeModelGenerator.Instance.GetWriter(serialization); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); } @@ -104,6 +126,8 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), // Add all the generated files to the workspace await Task.WhenAll(generateFilesTasks); + ProviderReferenceMapAnalyzer.RestorePreWriteModelFactoryMethods(); + LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files @@ -112,14 +136,22 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); await generatedCodeWorkspace.PostProcessAsync(); + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); - // Write the generated files to the output directory + var generatedFiles = new List<(string Name, string Text)>(); await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) { if (string.IsNullOrEmpty(file.Text)) { continue; } + + generatedFiles.Add((file.Name, file.Text)); + } + + // Write the generated files to the output directory + foreach (var file in generatedFiles) + { var filename = Path.Combine(outputPath, file.Name); CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); Directory.CreateDirectory(Path.GetDirectoryName(filename)!); @@ -177,9 +209,10 @@ private static void DeleteDirectory(string path, string[] filesToKeep) return; } + var fileNamesToKeep = filesToKeep.ToHashSet(StringComparer.Ordinal); foreach (var file in directoryInfo.GetFiles("*", SearchOption.AllDirectories)) { - if (!filesToKeep.Contains(file.Name)) + if (!fileNamesToKeep.Contains(file.Name)) { file.Delete(); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs index e9e016c3554..6df7ccc9758 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs @@ -168,7 +168,7 @@ protected internal virtual void VisitLibrary(OutputLibrary library) /// /// The original . /// Null if it should be removed otherwise the modified version of the . - protected internal virtual ConstructorProvider? VisitConstructor(ConstructorProvider constructor) + protected virtual ConstructorProvider? VisitConstructor(ConstructorProvider constructor) { return constructor; } @@ -302,7 +302,7 @@ protected internal virtual FinallyExpression VisitFinallyExpression(FinallyExpre /// /// The original . /// Null if it should be removed otherwise the modified version of the . - protected internal virtual FieldProvider? VisitField(FieldProvider field) + protected virtual FieldProvider? VisitField(FieldProvider field) { return field; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index c36686f637f..74f9fc6b971 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -84,6 +84,16 @@ public async Task AddInMemoryFile(TypeProvider type) await UpdateProject(document); } + internal void AnalyzeProviderReferenceMap(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.Analyze(providers); + } + + internal void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility(providers); + } + private async Task UpdateProject(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -278,10 +288,8 @@ public async Task PostProcessAsync() case Configuration.UnreferencedTypesHandlingOption.KeepAll: break; case Configuration.UnreferencedTypesHandlingOption.Internalize: - _project = await postProcessor.InternalizeAsync(_project); break; case Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize: - _project = await postProcessor.InternalizeAsync(_project); _project = await postProcessor.RemoveAsync(_project); break; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index dc42f801732..bef5dce85b4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -9,7 +9,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Simplification; namespace Microsoft.TypeSpec.Generator { @@ -113,58 +112,6 @@ private async Task GetTypeSymbolsAsync(Compilation compilation, protected virtual bool ShouldIncludeDocument(Document document) => !GeneratedCodeWorkspace.IsGeneratedTestDocument(document); - /// - /// This method marks the "not publicly" referenced types as internal if they are previously defined as public. It will do this job in the following steps: - /// 1. This method will read all the public types defined in the given , and build a cache for those symbols - /// 2. Build a public reference map for those symbols - /// 3. Finds all the root symbols, please override the to control which document you would like to include - /// 4. Visit all the symbols starting from the root symbols following the reference map to get all unvisited symbols - /// 5. Change the accessibility of the unvisited symbols in step 4 to internal - /// - /// The project to process - /// The processed . is immutable, therefore this should usually be a new instance - public async Task InternalizeAsync(Project project) - { - var compilation = await project.GetCompilationAsync(); - if (compilation == null) - { - return project; - } - - // first get all the declared symbols - var definitions = await GetTypeSymbolsAsync(compilation, project, true); - // build the reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DeclaredNodesCache); - // get the root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // traverse all the root and recursively add all the things we met - var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); - - var nodesToInternalize = new Dictionary(); - foreach (var symbol in symbolsToInternalize) - { - foreach (var node in definitions.DeclaredNodesCache[symbol]) - { - nodesToInternalize[node] = project.GetDocumentId(node.SyntaxTree)!; - } - } - - foreach (var (model, documentId) in nodesToInternalize) - { - project = MarkInternal(project, model, documentId); - } - - var modelNamesToRemove = - nodesToInternalize.Keys.Select(item => item.Identifier.Text); - project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet()); - - return project; - } - private async Task RemoveMethodsFromModelFactoryAsync(Project project, TypeSymbols definitions, HashSet namesToRemove) @@ -246,25 +193,32 @@ public async Task RemoveAsync(Project project) // find all the declarations, including non-public declared var definitions = await GetTypeSymbolsAsync(compilation, project, false); - // build reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DocumentsCache); - // get root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal - // helpers that are required by the model factory. - if (_modelFactorySymbol != null) + IEnumerable symbolsToRemove; + HashSet referencedSet; + if (ProviderReferenceMapAnalyzer.LatestResult is { } referenceMapResult) { - rootSymbols.Add(_modelFactorySymbol); + // The remove pass uses the same precomputed hybrid map to avoid scanning all generated + // documents with Roslyn while preserving custom-code references as roots. + symbolsToRemove = GetSymbolsByName(definitions.DeclaredSymbols, referenceMapResult.RemoveCandidates).ToArray(); + referencedSet = new HashSet(definitions.DeclaredSymbols.Except(symbolsToRemove), SymbolEqualityComparer.Default); } - // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); - var referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + else + { + var referenceMap = await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DocumentsCache); + // Include model factory as a root symbol when doing the remove pass so that we are sure to include any internal + // helpers that are required by the model factory. + var rootSymbols = await GetRootSymbolsAsync(project, definitions); + if (_modelFactorySymbol != null) + { + rootSymbols.Add(_modelFactorySymbol); + } - var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); + referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + } var nodesToRemove = new List(); foreach (var symbol in symbolsToRemove) @@ -276,6 +230,14 @@ public async Task RemoveAsync(Project project) nodesToRemove.AddRange(definitions.DeclaredNodesCache[symbol]); } + var modelNamesToRemove = nodesToRemove + .Select(static item => item.Identifier.Text) + .ToHashSet(StringComparer.Ordinal); + if (modelNamesToRemove.Count > 0) + { + project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove); + } + // remove them one by one project = await RemoveModelsAsync(project, nodesToRemove); @@ -352,18 +314,19 @@ private static IEnumerable GetReferencedTypes(T definition, return Enumerable.Empty(); } - private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId) + private static IEnumerable GetSymbolsByName(IEnumerable symbols, HashSet names) { - var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); - var tree = declarationNode.SyntaxTree; - var document = project.GetDocument(documentId)!; - var newRoot = tree.GetRoot().ReplaceNode(declarationNode, newNode) - .WithAdditionalAnnotations(Simplifier.Annotation); - document = document.WithSyntaxRoot(newRoot); - return document.Project; + foreach (var symbol in symbols) + { + if (names.Contains(symbol.GetFullyQualifiedName())) + { + yield return symbol; + } + } } - private async Task RemoveModelsAsync(Project project, + private async Task RemoveModelsAsync( + Project project, IEnumerable unusedModels) { // accumulate the definitions from the same document together @@ -392,24 +355,6 @@ private async Task RemoveModelsAsync(Project project, return project; } - private static BaseTypeDeclarationSyntax ChangeModifier(BaseTypeDeclarationSyntax memberDeclaration, - SyntaxKind from, - SyntaxKind to) - { - var originalTokenInList = memberDeclaration.Modifiers.FirstOrDefault(token => token.IsKind(from)); - - // skip this if there is nothing to replace - if (originalTokenInList == default) - { - return memberDeclaration; - } - - var newToken = - SyntaxFactory.Token(originalTokenInList.LeadingTrivia, to, originalTokenInList.TrailingTrivia); - var newModifiers = memberDeclaration.Modifiers.Replace(originalTokenInList, newToken); - return memberDeclaration.WithModifiers(newModifiers); - } - private async Task RemoveModelsFromDocumentAsync(Project project, IEnumerable models) { @@ -479,7 +424,14 @@ private async Task RemoveInvalidUsings(Solution solution, DocumentId d if (invalidUsings.Count > 0) { + var leadingTrivia = invalidUsings[0].GetLeadingTrivia(); cu = cu.RemoveNodes(invalidUsings, SyntaxRemoveOptions.KeepNoTrivia)!; + if (leadingTrivia.Count > 0) + { + var firstToken = cu.GetFirstToken(includeZeroWidth: true); + cu = cu.ReplaceToken(firstToken, firstToken.WithLeadingTrivia(leadingTrivia.AddRange(firstToken.LeadingTrivia))); + } + solution = solution.WithDocumentSyntaxRoot(documentId, cu); } @@ -497,30 +449,37 @@ private async Task RemoveInvalidAttributes(Solution solution, Document return solution; } - var attributes = cu.DescendantNodes().OfType(); - var firstAttribute = attributes.FirstOrDefault(); + var attributeLists = cu.DescendantNodes().OfType().ToArray(); + var firstAttributeList = attributeLists.FirstOrDefault(); - var invalidAttributes = attributes - .Where(attr => attr.Attributes.Any(attribute => + var invalidAttributes = attributeLists + .SelectMany(static attr => attr.Attributes) + .Where(attribute => attribute.ArgumentList?.Arguments.Any(arg => arg.Expression is TypeOfExpressionSyntax typeOfExpr && - model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true)) + model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true) .ToHashSet(); if (invalidAttributes.Count > 0) { + var firstAttributeListRemoved = firstAttributeList != null && + firstAttributeList.Attributes.All(invalidAttributes.Contains); + var leadingTrivia = firstAttributeList?.GetLeadingTrivia(); cu = cu.RemoveNodes(invalidAttributes, SyntaxRemoveOptions.KeepNoTrivia)!; + var emptyAttributeLists = cu.DescendantNodes().OfType() + .Where(static list => list.Attributes.Count == 0) + .ToArray(); + cu = cu.RemoveNodes(emptyAttributeLists, SyntaxRemoveOptions.KeepNoTrivia)!; - if (invalidAttributes.Contains(firstAttribute!)) + if (firstAttributeListRemoved && leadingTrivia != null) { - var leadingTrivia = firstAttribute!.GetLeadingTrivia(); // Find where XML docs end and indentation begins var xmlDocTrivia = new List(); var lastXmlIndex = -1; - for (int i = 0; i < leadingTrivia.Count; i++) + for (int i = 0; i < leadingTrivia.Value.Count; i++) { - var trivia = leadingTrivia[i]; + var trivia = leadingTrivia.Value[i]; if (trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia)) { lastXmlIndex = i; @@ -532,14 +491,14 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && { for (int i = 0; i <= lastXmlIndex; i++) { - xmlDocTrivia.Add(leadingTrivia[i]); + xmlDocTrivia.Add(leadingTrivia.Value[i]); } // Include the newline after the last XML doc if present - if (lastXmlIndex + 1 < leadingTrivia.Count && - leadingTrivia[lastXmlIndex + 1].IsKind(SyntaxKind.EndOfLineTrivia)) + if (lastXmlIndex + 1 < leadingTrivia.Value.Count && + leadingTrivia.Value[lastXmlIndex + 1].IsKind(SyntaxKind.EndOfLineTrivia)) { - xmlDocTrivia.Add(leadingTrivia[lastXmlIndex + 1]); + xmlDocTrivia.Add(leadingTrivia.Value[lastXmlIndex + 1]); } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs new file mode 100644 index 00000000000..459f6cbc901 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -0,0 +1,2222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceMapResult? _latestResult; + private static readonly ConditionalWeakTable, Dictionary> _simpleNameLookupCache = new(); + private static TypeProvider? _preWriteModelFactory; + private static MethodProvider[]? _preWriteModelFactoryMethods; + + public static ProviderReferenceMapResult? LatestResult => _latestResult; + public static bool PreWriteAccessibilityApplied { get; private set; } + + public static bool ShouldWriteProvider(TypeProvider provider) => + _latestResult?.RemoveCandidates.Contains(GetProviderTypeName(provider.Type)) != true; + + public static void ResetPreWriteAccessibility() + { + RestorePreWriteModelFactoryMethods(); + _latestResult = null; + PreWriteAccessibilityApplied = false; + } + + public static void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + PreWriteAccessibilityApplied = false; + if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.KeepAll) + { + return; + } + + var (internalizeCandidates, publicizeCandidates) = GetPreWriteAccessibilityCandidates(providers); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (internalizeCandidates.Contains(providerName)) + { + provider.PreserveXmlDocs(); + provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers)); + } + else if (publicizeCandidates.Contains(providerName)) + { + provider.Update(modifiers: MakePublic(provider.DeclarationModifiers)); + } + } + + RemoveMethodsFromModelFactory(GetSimpleNames(internalizeCandidates)); + PreWriteAccessibilityApplied = true; + } + + public static void RestorePreWriteModelFactoryMethods() + { + if (_preWriteModelFactory == null || _preWriteModelFactoryMethods == null) + { + return; + } + + _preWriteModelFactory.Update(methods: _preWriteModelFactoryMethods); + _preWriteModelFactory = null; + _preWriteModelFactoryMethods = null; + } + + public static void Analyze(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customCodeRemovalRoots = GetCustomCodeGeneratedTypeRoots(generatedProviders, graph.Nodes); + var customRemovalRoots = new HashSet(customCodeRemovalRoots, StringComparer.Ordinal); + customRemovalRoots.UnionWith(apiBaselineGeneratedTypeRoots); + customRemovalRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + + // Helper types are rooted after an initial reachability pass so unused infrastructure + // such as change-tracking dictionaries can still be removed when no reachable type needs them. + var generatedDiscriminatorBaseNames = GetGeneratedPersistableModelProxyTypeNames(generatedProviders, publicGraph.Nodes); + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(providers, graph); + var removeCandidates = GetRemovalCandidates( + providers, + generatedProviders, + graph, + customRemovalRoots, + generatedDiscriminatorBaseNames); + + _latestResult = new ProviderReferenceMapResult( + internalizeCandidates, + publicizeCandidates, + removeCandidates); + RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates) GetPreWriteAccessibilityCandidates(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + var generatedDiscriminatorBaseNames = new HashSet(StringComparer.Ordinal); + + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + return (internalizeCandidates, publicizeCandidates); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates, HashSet InternalizeHelperRoots) GetAccessibilityCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + ProviderReferenceGraph publicGraph, + HashSet customPublicRoots, + HashSet customInternalDeclarations, + HashSet generatedInternalDeclarations, + HashSet generatedDiscriminatorBaseNames) + { + var internalizeReferences = CloneReferences(publicGraph.References); + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false, includeAdditionalRoots: true, includeUnionVariantRoots: false, publicClientRootsOnly: true); + if (ShouldUseUnionVariantFallbackRoots()) + { + AddUnionVariantRoots(internalizeRoots, providers, graph.Nodes); + } + + var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable, generatedDiscriminatorBaseNames); + internalizeRoots.UnionWith(customPublicRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + var publicizeRoots = new HashSet(internalizeRoots, StringComparer.Ordinal); + var internalizeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); + var customInternalBoundaryNodes = GetCustomInternalBoundaryNodes(publicGraph, customInternalDeclarations); + var publicizeDeclaredNodes = GetPublicizeDeclaredNodes(generatedProviders, graph.Nodes, internalizeDeclaredNodes); + var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedInternalDeclarations); + var publicApiTraversalNodes = GetPublicApiTraversalNodes( + internalizeDeclaredNodes, + publicizeDeclaredNodes, + generatedInternalDeclarations, + generatedImplementationInternalDeclarations); + var publicizeReachable = GetReachableTypes(publicizeRoots, internalizeReferences, publicApiTraversalNodes); + var internalizeCandidates = GetInternalizeCandidates( + internalizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + publicizeRoots); + var publicizeRootExclusions = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: true, + publicClientRootsOnly: true); + var publicizeCandidates = GetPublicizeCandidates( + publicizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + internalizeHelperRoots, + publicizeRootExclusions, + generatedInternalDeclarations, + publicizeRoots, + internalizeReferences, + generatedImplementationInternalDeclarations); + + return (internalizeCandidates, publicizeCandidates, internalizeHelperRoots); + } + + private static HashSet GetCustomInternalBoundaryNodes( + ProviderReferenceGraph publicGraph, + HashSet customInternalDeclarations) + { + var boundaryNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in publicGraph.Nodes) + { + if (!publicGraph.References.TryGetValue(node, out var references)) + { + continue; + } + + if (references.Overlaps(customInternalDeclarations)) + { + boundaryNodes.Add(node); + } + } + + return boundaryNodes; + } + + private static HashSet GetPublicizeDeclaredNodes( + IReadOnlyList generatedProviders, + HashSet nodes, + HashSet internalizeDeclaredNodes) + { + var publicizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, nodes, publicOnly: false); + publicizeDeclaredNodes.ExceptWith(internalizeDeclaredNodes); + return publicizeDeclaredNodes; + } + + private static HashSet GetPublicApiTraversalNodes( + HashSet internalizeDeclaredNodes, + HashSet publicizeDeclaredNodes, + HashSet generatedInternalDeclarations, + HashSet generatedImplementationInternalDeclarations) + { + var traversalNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (generatedInternalDeclarations.Contains(node) || + generatedImplementationInternalDeclarations.Contains(node)) + { + continue; + } + + traversalNodes.Add(node); + } + + foreach (var node in publicizeDeclaredNodes) + { + if (!generatedImplementationInternalDeclarations.Contains(node)) + { + traversalNodes.Add(node); + } + } + + return traversalNodes; + } + + private static HashSet GetInternalizeCandidates( + HashSet internalizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet publicizeRoots) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (!publicizeReachable.Contains(node) || + customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) && !publicizeRoots.Contains(node)) + { + candidates.Add(node); + } + } + + return candidates; + } + + private static HashSet GetPublicizeCandidates( + HashSet publicizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet internalizeHelperRoots, + HashSet publicizeRootExclusions, + HashSet generatedInternalDeclarations, + HashSet publicizeRoots, + Dictionary> internalizeReferences, + HashSet generatedImplementationInternalDeclarations) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in publicizeDeclaredNodes) + { + if (customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) || + internalizeHelperRoots.Contains(node) || + publicizeRootExclusions.Contains(node) || + !publicizeReachable.Contains(node)) + { + continue; + } + + if (generatedInternalDeclarations.Contains(node) && !publicizeRoots.Contains(node)) + { + continue; + } + + if (!publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, internalizeReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + { + continue; + } + + candidates.Add(node); + } + + return candidates; + } + + private static HashSet GetRemovalCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + HashSet customRemovalRoots, + HashSet generatedDiscriminatorBaseNames) + { + var removeRoots = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: false, + publicClientRootsOnly: false); + + removeRoots.UnionWith(customRemovalRoots); + AddMatchingNamesWithSimpleNameSuffix(removeRoots, "ReferenceType", graph.Nodes); + AddCustomCodeExtensionRoots(removeRoots, generatedProviders, graph.Nodes); + AddCustomizationBackedExtensionRoots(removeRoots, graph.Nodes); + AddCustomRequestHeaderExtensionsRoot(removeRoots, generatedProviders, graph.Nodes); + RemoveUnusedRequestHeaderExtensionsRoot(removeRoots, graph.References, providers); + + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddDerivedModelReferences(providers, graph.Nodes, graph.References, removeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachableWithoutHelpers); + + var removeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, removeReachableWithoutHelpers, graph.References); + removeRoots.UnionWith(removeHelperRoots); + + var removeReachable = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachable); + + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: false); + removeDeclaredNodes.ExceptWith(removeReachable); + return removeDeclaredNodes; + } + + private static HashSet GetCustomCodeGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: false); + } + + return roots; + } + + private static HashSet GetCustomCodePublicGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: true); + } + + return roots; + } + + private static IEnumerable GetCustomCodeViews(IReadOnlyList providers) + { + var visited = new HashSet(StringComparer.Ordinal); + var modelFactoryCustomCodeView = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value.CustomCodeView; + if (modelFactoryCustomCodeView != null && visited.Add(GetCustomCodeViewIdentity(modelFactoryCustomCodeView))) + { + yield return modelFactoryCustomCodeView; + } + + foreach (var provider in providers) + { + var customCodeView = provider.CustomCodeView; + if (customCodeView == null || !visited.Add(GetCustomCodeViewIdentity(customCodeView))) + { + continue; + } + + yield return customCodeView; + } + + foreach (var customTypeProvider in CodeModelGenerator.Instance.SourceInputModel.GetCustomizationTypeProviders()) + { + if (visited.Add(GetCustomCodeViewIdentity(customTypeProvider))) + { + yield return customTypeProvider; + } + } + } + + private static string GetCustomCodeViewIdentity(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataName + : GetProviderTypeName(customCodeView.Type); + + private static void AddCustomRequestHeaderExtensionsRoot(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + if (!HasCustomRequestHeaderExtensionsReference(providers)) + { + return; + } + + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeaderExtensions", nodes); + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeadersExtensions", nodes); + } + + private static void AddCustomCodeExtensionRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", nodes); + } + } + + private static string GetCustomCodeViewSimpleName(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataSimpleName + : customCodeView.Type.Name; + + private static void AddCustomizationBackedExtensionRoots(HashSet roots, HashSet nodes) + { + foreach (var node in nodes) + { + var simpleName = GetSimpleName(node); + if (!simpleName.EndsWith("Extensions", StringComparison.Ordinal)) + { + continue; + } + + var namespaceName = GetNamespaceName(node); + if (namespaceName == null) + { + continue; + } + + var customTypeName = simpleName.Substring(0, simpleName.Length - "Extensions".Length); + if (CodeModelGenerator.Instance.SourceInputModel.FindForTypeInCustomization(namespaceName, customTypeName) != null) + { + roots.Add(node); + } + } + } + + private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames, bool publicOnly) + { + if (customCodeView is NamedTypeSymbolProvider) + { + AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + } + + AddTypeReference(roots, customCodeView.BaseType, generatedTypeNames); + AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + if (!publicOnly) + { + AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); + } + + foreach (var implementedType in customCodeView.Implements) + { + AddTypeReference(roots, implementedType, generatedTypeNames); + } + + foreach (var constructor in customCodeView.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, constructor.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var method in customCodeView.Methods) + { + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, method.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var property in customCodeView.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(roots, property.Type, generatedTypeNames); + AddTypeReference(roots, property.ExplicitInterface, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, property.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + + foreach (var field in customCodeView.Fields) + { + if (publicOnly && !IsPublic(field.Modifiers)) + { + continue; + } + + AddTypeReference(roots, field.Type, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, field.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + } + + private static HashSet GetApiBaselineGeneratedTypeRoots(HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + if (string.IsNullOrEmpty(projectDirectory)) + { + return roots; + } + + var apiDirectory = Path.GetFullPath(Path.Combine(projectDirectory, "..", "api")); + if (!Directory.Exists(apiDirectory)) + { + return roots; + } + + var apiText = string.Join("\n", Directory.GetFiles(apiDirectory, "*.cs", SearchOption.AllDirectories).Select(File.ReadAllText)); + var apiDeclaredTypeNames = GetApiDeclaredTypeNames(apiText); + foreach (var fullName in generatedTypeNames) + { + var simpleName = StripGenericArity(GetSimpleName(fullName)); + var normalizedFullName = StripGenericArity(fullName); + if (!ContainsApiTypeReference(apiText, apiDeclaredTypeNames, normalizedFullName, simpleName)) + { + continue; + } + + roots.Add(fullName); + } + + return roots; + } + + private static HashSet GetApiDeclaredTypeNames(string apiText) + { + var declaredTypeNames = new HashSet(StringComparer.Ordinal); + string? currentNamespace = null; + foreach (var line in apiText.Split('\n')) + { + var namespaceMatch = Regex.Match(line, @"^namespace\s+([\w.]+)\s*\{?\s*$"); + if (namespaceMatch.Success) + { + currentNamespace = namespaceMatch.Groups[1].Value; + continue; + } + + if (currentNamespace == null) + { + continue; + } + + var declarationMatch = Regex.Match(line, @"^ \S.*?\b(class|struct|interface|enum)\s+([A-Za-z_][A-Za-z0-9_]*)(?!\s*<)(?!\w)"); + if (declarationMatch.Success) + { + declaredTypeNames.Add($"{currentNamespace}.{declarationMatch.Groups[2].Value}"); + } + } + + return declaredTypeNames; + } + + private static bool ContainsApiTypeReference(string apiText, HashSet apiDeclaredTypeNames, string fullName, string simpleName) + { + var fullNamePattern = $@"(? GetCustomCodeInternalGeneratedTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)) + { + continue; + } + + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(declarations, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + } + else + { + AddTypeReference(declarations, customCodeView.Type, generatedTypeNames); + } + } + + return declarations; + } + + private static HashSet GetGeneratedPersistableModelProxyTypeNames(IReadOnlyList providers, HashSet generatedTypeNames) + { + var proxyTypes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.Attributes.Any(static attribute => IsAttributeNamed(attribute, "PersistableModelProxy"))) + { + AddTypeReference(proxyTypes, provider.Type, generatedTypeNames); + } + } + + return proxyTypes; + } + + private static HashSet GetGeneratedInternalTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Internal); + + private static HashSet GetGeneratedPublicTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Public); + + private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessibility( + IReadOnlyList providers, + HashSet generatedTypeNames, + TypeSignatureModifiers accessibility) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.LastContractView?.DeclarationModifiers.HasFlag(accessibility) != true) + { + continue; + } + + AddTypeReference(declarations, provider.Type, generatedTypeNames); + } + + return declarations; + } + + private static HashSet GetGeneratedImplementationInternalTypeDeclarations(HashSet generatedInternalDeclarations) + { + var implementationDeclarations = new HashSet(StringComparer.Ordinal); + foreach (var name in generatedInternalDeclarations) + { + if (GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)) + { + implementationDeclarations.Add(name); + } + } + + return implementationDeclarations; + } + + private static HashSet GetSimpleNames(HashSet names) + { + var simpleNames = new HashSet(StringComparer.Ordinal); + foreach (var name in names) + { + simpleNames.Add(GetSimpleName(name)); + } + + return simpleNames; + } + + private static ProviderReferenceGraph BuildGraph(IReadOnlyList generatedProviders, bool publicOnly = false) + { + // Each generated provider becomes a node, and provider metadata supplies the edges: + // inheritance, signatures, properties, fields, nested/serialization providers, attributes, + // and selected implementation dependencies. This avoids parsing generated C# just to + // rediscover generated-to-generated references. + var serializationProviderNamesByType = GetSerializationProviderNamesByType(generatedProviders); + IReadOnlyDictionary? serializationReferenceNamesByType = publicOnly ? null : serializationProviderNamesByType; + var nodes = new HashSet(StringComparer.Ordinal); + var references = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (nodes.Add(providerName)) + { + references.Add(providerName, new HashSet(StringComparer.Ordinal)); + } + } + + foreach (var provider in generatedProviders) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.BaseType, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationReferenceNamesByType); + + if (IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) + { + continue; + } + + // Model factory signatures mention many models. The existing Roslyn post-processor + // removes factory methods for unreachable models, so model factory should only + // contribute helper dependencies, not model reachability edges. + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes, serializationReferenceNamesByType); + } + + if (!publicOnly) + { + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes, serializationReferenceNamesByType); + } + } + + if (!publicOnly) + { + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes, serializationReferenceNamesByType); + } + } + + foreach (var property in provider.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(references[current], property.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], property.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var field in provider.Fields) + { + if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) + { + continue; + } + + AddTypeReference(references[current], field.Type, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], field.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var constructor in provider.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], constructor.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], method.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + if (!publicOnly) + { + AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationReferenceNamesByType); + } + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static Dictionary GetSerializationProviderNamesByType(IReadOnlyList generatedProviders) + { + var namesByType = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + if (provider.SerializationProviders.Count == 0) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!namesByType.TryGetValue(providerName, out var serializationProviderNames)) + { + serializationProviderNames = new HashSet(StringComparer.Ordinal); + namesByType.Add(providerName, serializationProviderNames); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + serializationProviderNames.Add(GetProviderTypeName(serializationProvider.Type)); + } + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (providerName, serializationProviderNames) in namesByType) + { + result.Add(providerName, [.. serializationProviderNames]); + } + + return result; + } + + private static CSharpType? GetCollectionDefinitionType(MethodProvider method) + { + var property = method.GetType().GetProperty("CollectionDefinition"); + return property?.GetValue(method) is TypeProvider collectionDefinition + ? collectionDefinition.Type + : null; + } + + private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); + private static bool IsPublic(FieldModifiers modifiers) => modifiers.HasFlag(FieldModifiers.Public); + + private static TypeSignatureModifiers MakeInternal(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Public | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Internal; + + private static TypeSignatureModifiers MakePublic(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Internal | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Public; + + private static Dictionary> CloneReferences(IReadOnlyDictionary> references) + { + var clone = new Dictionary>(StringComparer.Ordinal); + foreach (var (name, referencedNames) in references) + { + clone.Add(name, new HashSet(referencedNames, StringComparer.Ordinal)); + } + + return clone; + } + + private static void AddDerivedModelReferences( + IReadOnlyList providers, + HashSet nodes, + Dictionary> references, + HashSet publicBaseModels, + HashSet generatedDiscriminatorBaseNames) + { + var modelProviders = new List(); + var discriminatorProviders = new List(); + var discriminatorBaseNames = new HashSet(StringComparer.Ordinal); + foreach (var provider in providers) + { + if (provider is not ModelProvider modelProvider || + !modelProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + modelProviders.Add(modelProvider); + + if (modelProvider.DiscriminatorProperty != null) + { + discriminatorBaseNames.Add(GetProviderTypeName(modelProvider.Type)); + } + + if (!modelProvider.IsUnknownDiscriminatorModel && + (modelProvider.DiscriminatorProperty != null || modelProvider.DiscriminatorValue != null)) + { + discriminatorProviders.Add(modelProvider); + } + } + + discriminatorBaseNames.UnionWith(generatedDiscriminatorBaseNames); + var addedReference = true; + while (addedReference) + { + addedReference = false; + foreach (var provider in discriminatorProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + if (!publicBaseModels.Contains(providerName)) + { + continue; + } + + foreach (var derivedModel in provider.DerivedModels) + { + if (derivedModel.IsUnknownDiscriminatorModel || + !derivedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var before = references[providerName].Count; + AddTypeReference(references[providerName], derivedModel.Type, nodes); + var derivedName = GetProviderTypeName(derivedModel.Type); + if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) + { + addedReference = true; + } + } + } + + foreach (var provider in modelProviders) + { + if (provider.IsUnknownDiscriminatorModel || + !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || + !discriminatorBaseNames.Contains(baseTypeName) || + !nodes.Contains(baseTypeName) || + !publicBaseModels.Contains(baseTypeName)) + { + continue; + } + + var before = references[baseTypeName].Count; + references[baseTypeName].Add(providerName); + if (publicBaseModels.Add(providerName) || references[baseTypeName].Count != before) + { + addedReference = true; + } + } + } + } + + private static void AddBasePreservedReferences( + IReadOnlyList providers, + HashSet nodes, + IReadOnlyDictionary> references, + HashSet reachableTypes) + { + var basePreservedRoots = new HashSet(StringComparer.Ordinal); + var addedRoot = true; + while (addedRoot) + { + addedRoot = false; + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName) || reachableTypes.Contains(providerName) || basePreservedRoots.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || !reachableTypes.Contains(baseTypeName)) + { + continue; + } + + if (basePreservedRoots.Add(providerName)) + { + addedRoot = true; + } + } + + if (addedRoot) + { + reachableTypes.UnionWith(GetReachableTypes(basePreservedRoots, references)); + } + } + } + + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + AddGeneratedProvider(generatedProviders, provider); + } + + return generatedProviders; + } + + private static void AddGeneratedProvider(List generatedProviders, TypeProvider provider) + { + generatedProviders.Add(provider); + foreach (var nestedType in provider.NestedTypes) + { + AddGeneratedProvider(generatedProviders, nestedType); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddGeneratedProvider(generatedProviders, serializationProvider); + } + } + + private static void AddGeneratedBodyReferences(IReadOnlyList providers, ProviderReferenceGraph graph) + { + foreach (var (provider, isSerializationProvider) in GetBodyReferenceProviders(providers)) + { + if (IsModelFactoryProvider(provider) || + !IsGeneratedBodyReferenceCandidate(provider, isSerializationProvider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, referencedNames: null); + AddProviderBodyDependencyTypes( + graph.References[providerName], + GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), + graph.Nodes); + AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); + AddProviderInfrastructureReferences(graph.References[providerName], provider, graph.Nodes); + } + } + + private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes(TypeProvider provider, HashSet nodes) + { + var references = new List(); + foreach (var dependency in CollectStructuredBodyReferenceTypes(provider)) + { + if (!IsEnumProviderDependency(dependency, nodes)) + { + references.Add(dependency); + } + } + + return references; + } + + private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddMatchingName(references, "ProviderConstants", nodes); + AddMatchingName(references, "TypeFormatters", nodes); + + if (provider.SerializationProviders.Count > 0) + { + AddSerializationExtensionReferences(references, provider, nodes); + } + + if (IsSerializationProvider(provider)) + { + AddMatchingName(references, "Optional", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + AddMatchingName(references, "ModelSerializationExtensions", nodes); + AddSerializationExtensionReferences(references, provider, nodes); + } + + foreach (var method in provider.Methods) + { + AddMethodInfrastructureReferences(references, method, nodes); + } + } + + private static void AddSerializationExtensionReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddSerializationExtensionReferences(references, provider.Type, nodes); + AddSerializationExtensionReferences(references, provider.BaseType, nodes); + foreach (var implementedType in provider.Implements) + { + AddSerializationExtensionReferences(references, implementedType, nodes); + } + + foreach (var property in provider.Properties) + { + AddSerializationExtensionReferences(references, property.Type, nodes); + } + + foreach (var field in provider.Fields) + { + AddSerializationExtensionReferences(references, field.Type, nodes); + } + + foreach (var constructor in provider.Constructors) + { + AddSerializationExtensionReferences(references, constructor.Signature.ReturnType, nodes); + foreach (var parameter in constructor.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + + foreach (var method in provider.Methods) + { + AddSerializationExtensionReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + } + + private static void AddSerializationExtensionReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + AddMatchingName(references, $"{type.Name}Extensions", nodes); + foreach (var argument in type.Arguments) + { + AddSerializationExtensionReferences(references, argument, nodes); + } + } + + private static void AddMethodInfrastructureReferences(HashSet references, MethodProvider method, HashSet nodes) + { + AddReturnTypeInfrastructureReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddRequestContentInfrastructureReferences(references, parameter.Type, nodes); + } + } + + private static void AddReturnTypeInfrastructureReferences(HashSet references, CSharpType? returnType, HashSet nodes) + { + var type = UnwrapTask(returnType); + if (type == null) + { + return; + } + + var typeName = StripGenericArity(type.Name); + if (string.Equals(typeName, "Pageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "PageableWrapper", nodes); + } + else if (string.Equals(typeName, "AsyncPageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "AsyncPageableWrapper", nodes); + } + else if (string.Equals(typeName, "ArmOperation", StringComparison.Ordinal)) + { + AddMatchingNamesWithSimpleNameSuffix(references, "ArmOperation", nodes); + AddMatchingNamesWithSimpleNameSuffix(references, "OperationSource", nodes); + if (type.Arguments.Count > 0) + { + AddMatchingName(references, $"{BuildOperationSourceTypeName(type.Arguments[0])}OperationSource", nodes); + } + } + } + + private static void AddRequestContentInfrastructureReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (string.Equals(type.Name, "RequestContent", StringComparison.Ordinal)) + { + AddMatchingName(references, "BinaryContentHelper", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + } + + foreach (var argument in type.Arguments) + { + AddRequestContentInfrastructureReferences(references, argument, nodes); + } + } + + private static CSharpType? UnwrapTask(CSharpType? type) + { + var typeName = type == null ? null : StripGenericArity(type.Name); + if ((string.Equals(typeName, "Task", StringComparison.Ordinal) || + string.Equals(typeName, "ValueTask", StringComparison.Ordinal)) && + type?.Arguments.Count > 0) + { + return type.Arguments[0]; + } + + return type; + } + + private static string BuildOperationSourceTypeName(CSharpType type) + { + var argumentNames = string.Join("", type.Arguments.Select(BuildOperationSourceTypeName)); + return $"{type.Name}{(argumentNames.Length > 0 ? "Of" : string.Empty)}{argumentNames}"; + } + + private static IReadOnlyList CollectStructuredBodyReferenceTypes(TypeProvider provider) + { + var references = new HashSet(); + var visited = new HashSet(ReferenceEqualityComparer.Instance); + + foreach (var field in provider.Fields) + { + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + } + + foreach (var property in provider.Properties) + { + CollectStructuredBodyReferenceTypes(property.Body, references, visited); + } + + foreach (var constructor in provider.Constructors) + { + CollectStructuredBodyReferenceTypes(constructor.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(constructor.BodyStatements, references, visited); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + CollectStructuredBodyReferenceTypes(method.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(method.BodyStatements, references, visited); + } + + return [.. references]; + } + + private static void CollectStructuredBodyReferenceTypes(object? value, HashSet references, HashSet visited) + { + switch (value) + { + case null: + case string: + case FormattableString: + return; + } + + if (!value.GetType().IsValueType && !visited.Add(value)) + { + return; + } + + switch (value) + { + case CSharpType type: + references.Add(type); + return; + case Type type: + references.Add(type); + return; + case ParameterProvider parameter: + references.Add(parameter.Type); + CollectStructuredBodyReferenceTypes(parameter.DefaultValue, references, visited); + CollectStructuredBodyReferenceTypes(parameter.InitializationValue, references, visited); + return; + case MethodSignatureBase signature: + CollectStructuredBodyReferenceTypes(signature.ReturnType, references, visited); + CollectStructuredBodyReferenceTypes(signature.Parameters, references, visited); + return; + case KeyValuePair positionalArgument: + CollectStructuredBodyReferenceTypes(positionalArgument.Value, references, visited); + return; + case FieldProvider field: + references.Add(field.Type); + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + return; + } + + if (IsStructuredBodyReferenceObject(value)) + { + foreach (var property in value.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + if (property.GetIndexParameters().Length > 0) + { + continue; + } + + CollectStructuredBodyReferenceTypes(property.GetValue(value), references, visited); + } + + return; + } + + if (value is not IEnumerable values) + { + return; + } + + foreach (var item in values) + { + CollectStructuredBodyReferenceTypes(item, references, visited); + } + } + + private static bool IsEnumProviderDependency(CSharpType dependency, HashSet nodes) + { + var providerName = GetProviderTypeName(dependency); + if (!nodes.Contains(providerName)) + { + return false; + } + + foreach (var provider in CodeModelGenerator.Instance.OutputLibrary.TypeProviders) + { + if (provider is EnumProvider && + string.Equals(GetProviderTypeName(provider.Type), providerName, StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + + private static bool IsStructuredBodyReferenceObject(object value) => + value is ValueExpression || + value is MethodBodyStatement || + value is PropertyBody; + + private static void AddProviderBodyDependencyTypes( + HashSet references, + IReadOnlyList dependencies, + HashSet nodes, + bool includeSimpleNameReferences = false) + { + foreach (var dependency in dependencies) + { + AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences); + } + } + + private static void AddProviderBodyDependencyType( + HashSet references, + CSharpType? dependency, + HashSet nodes, + bool includeSimpleNameReferences) + { + if (dependency == null) + { + return; + } + + AddTypeReference(references, dependency, nodes); + if (includeSimpleNameReferences) + { + AddMatchingName(references, dependency.Name, nodes); + } + AddMatchingName(references, $"{dependency.Name}Extensions", nodes); + + foreach (var argument in dependency.Arguments) + { + AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences); + } + } + + private static IReadOnlyList<(TypeProvider Provider, bool IsSerializationProvider)> GetBodyReferenceProviders(IReadOnlyList providers) + { + var bodyReferenceProviders = new List<(TypeProvider Provider, bool IsSerializationProvider)>(); + foreach (var provider in providers) + { + bodyReferenceProviders.Add((provider, false)); + foreach (var serializationProvider in provider.SerializationProviders) + { + bodyReferenceProviders.Add((serializationProvider, true)); + } + } + + return bodyReferenceProviders; + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, bool isSerializationProvider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + return provider.IsClientProvider || + isSerializationProvider || + provider.IncludeGeneratedBodyReferences || + provider.HelperDependencyTypes.Count > 0 || + provider.BodyDependencyTypes.Count > 0; + } + + private static HashSet GetRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet helperRoots, + bool includeModelFactory, + bool includeAdditionalRoots, + bool includeUnionVariantRoots, + bool publicClientRootsOnly) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (IsClientProviderRoot(provider, publicClientRootsOnly) || + includeAdditionalRoots && IsAdditionalRootProvider(provider, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + AddLastContractModelFactorySignatureRoots(providers, roots, nodes); + + if (!includeUnionVariantRoots) + { + return roots; + } + + AddUnionVariantRoots(roots, providers, nodes); + + return roots; + } + + private static void AddLastContractModelFactorySignatureRoots(IReadOnlyList providers, HashSet roots, HashSet nodes) + { + foreach (var provider in providers) + { + if (!IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var method in provider.LastContractView?.Methods ?? []) + { + if (!method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public) || + IsImplementationOnlyModelFactoryMethod(method)) + { + continue; + } + + AddTypeReference(roots, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddTypeReference(roots, parameter.Type, nodes); + } + } + } + } + + private static void AddUnionVariantRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + var unionVariantTypesToKeep = CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep; + foreach (var provider in GetGeneratedProviders(providers)) + { + if (!unionVariantTypesToKeep.Contains(provider.Type.Name) || + string.Equals(provider.Type.Namespace, "TypeSpec.Http", StringComparison.Ordinal)) + { + continue; + } + + AddMatchingName(roots, GetProviderTypeName(provider.Type), nodes); + } + } + + private static bool ShouldUseUnionVariantFallbackRoots() => + !HasApiBaselineDirectory() && + CodeModelGenerator.Instance.SourceInputModel.LastContract == null; + + private static bool IsImplementationOnlyModelFactoryMethod(MethodProvider method) + { + var returnType = method.Signature.ReturnType; + if (returnType == null) + { + return true; + } + + var returnTypeName = GetSimpleName(GetProviderTypeName(returnType)); + return returnTypeName.StartsWith("Paged", StringComparison.Ordinal) || + returnTypeName.EndsWith("Request", StringComparison.Ordinal); + } + + private static void RemoveMethodsFromModelFactory(HashSet namesToRemove) + { + if (namesToRemove.Count == 0) + { + return; + } + + var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value; + _preWriteModelFactory = modelFactory; + _preWriteModelFactoryMethods ??= [.. modelFactory.Methods]; + var methodsToKeep = new List(); + foreach (var method in modelFactory.Methods) + { + if (!namesToRemove.Contains(method.Signature.Name)) + { + methodsToKeep.Add(method); + } + } + + modelFactory.Update(methods: methodsToKeep); + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var excludedNames = generator.NonRootTypes; + var declaredNodes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (IsModelFactoryProvider(provider)) + { + continue; + } + + if (publicOnly && !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var name = GetProviderTypeName(provider.Type); + if (!nodes.Contains(name) || + excludedNames.Contains(name) || + excludedNames.Contains(GetSimpleName(name))) + { + continue; + } + + declaredNodes.Add(name); + } + + return declaredNodes; + } + + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) + { + var providerName = GetProviderTypeName(type); + if (roots.Contains(providerName) && nodes.Contains(providerName)) + { + return true; + } + + if (!roots.Contains(type.Name)) + { + return false; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + return simpleNameLookup.TryGetValue(type.Name, out var matches) && + matches.Length == 1 && + string.Equals(matches[0], providerName, StringComparison.Ordinal); + } + + private static bool IsClientProviderRoot(TypeProvider provider, bool publicOnly) => + provider.IsClientProvider && + (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + + private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) + { + if (provider.DeclaringTypeProvider != null || !IsKept(provider.Type, roots, nodes)) + { + return false; + } + + return provider is not ModelProvider && provider is not EnumProvider; + } + + private static bool HasApiBaselineDirectory() + { + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + return !string.IsNullOrEmpty(projectDirectory) && + Directory.Exists(Path.GetFullPath(Path.Combine(projectDirectory, "..", "api"))); + } + + private static bool IsModelFactoryProvider(TypeProvider provider) + => provider is ModelFactoryProvider; + + private static HashSet GetHelperRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet reachableTypes, + IReadOnlyDictionary>? references = null) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + AddHelperDependencies(roots, provider.HelperDependencyTypes, nodes, references == null ? null : references[providerName]); + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + // Only factory methods for reachable models can instantiate collection helpers. + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddHelperDependencies( + HashSet roots, + IReadOnlyList dependencies, + HashSet nodes, + HashSet? referencedNames) + { + foreach (var dependency in dependencies) + { + if (referencedNames == null) + { + AddTypeReference(roots, dependency, nodes); + continue; + } + + var matches = new HashSet(StringComparer.Ordinal); + AddTypeReference(matches, dependency, nodes); + foreach (var match in matches) + { + if (referencedNames.Contains(match)) + { + roots.Add(match); + } + } + } + } + + private static void RemoveUnusedRequestHeaderExtensionsRoot( + HashSet roots, + IReadOnlyDictionary> references, + IReadOnlyList providers) + { + var hasCustomReference = HasCustomRequestHeaderExtensionsReference(providers); + if (hasCustomReference) + { + return; + } + + var unusedRequestHeaderExtensions = new List(); + foreach (var root in roots) + { + if (IsRequestHeadersExtensionsRoot(root) && + !HasExternalReference(root, references)) + { + unusedRequestHeaderExtensions.Add(root); + } + } + + roots.ExceptWith(unusedRequestHeaderExtensions); + } + + private static bool HasExternalReference(string root, IReadOnlyDictionary> references) + { + foreach (var (source, sourceReferences) in references) + { + if (!string.Equals(source, root, StringComparison.Ordinal) && + sourceReferences.Contains(root)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeadersExtensionsRoot(string root) => + root.EndsWith(".RequestHeaderExtensions", StringComparison.Ordinal) || + root.EndsWith(".RequestHeadersExtensions", StringComparison.Ordinal); + + private static bool HasCustomRequestHeaderExtensionsReference(IReadOnlyList providers) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (customCodeView is NamedTypeSymbolProvider) + { + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.SignatureDependencyTypes)) + { + return true; + } + + continue; + } + + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsMethodDependency(customCodeView.Methods) || + HasRequestHeaderExtensionsPropertyDependency(customCodeView.Properties) || + HasRequestHeaderExtensionsFieldDependency(customCodeView.Fields)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsDependency(IEnumerable dependencies) + { + foreach (var dependency in dependencies) + { + if (IsRequestHeaderExtensionsDependency(dependency)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsMethodDependency(IReadOnlyList methods) + { + foreach (var method in methods) + { + if (IsRequestHeaderExtensionsDependency(method.Signature.ReturnType)) + { + return true; + } + + foreach (var parameter in method.Signature.Parameters) + { + if (IsRequestHeaderExtensionsDependency(parameter.Type)) + { + return true; + } + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsPropertyDependency(IReadOnlyList properties) + { + foreach (var property in properties) + { + if (IsRequestHeaderExtensionsDependency(property.Type)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsFieldDependency(IReadOnlyList fields) + { + foreach (var field in fields) + { + if (IsRequestHeaderExtensionsDependency(field.Type)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeaderExtensionsDependency(string name) + => string.Equals(name, "RequestHeaderExtensions", StringComparison.Ordinal) || + string.Equals(name, "SetDelimited", StringComparison.Ordinal); + + private static bool IsRequestHeaderExtensionsDependency(CSharpType? type) + { + if (type == null) + { + return false; + } + + if (IsRequestHeaderExtensionsDependency(type.Name)) + { + return true; + } + + foreach (var argument in type.Arguments) + { + if (IsRequestHeaderExtensionsDependency(argument)) + { + return true; + } + } + + return false; + } + + private static bool IsSerializationProvider(TypeProvider provider) + { + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || + relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + if (!simpleNameLookup.TryGetValue(name, out var matches)) + { + return; + } + + foreach (var match in matches) + { + target.Add(match); + } + } + + private static void AddMatchingNamesWithSimpleNameSuffix(HashSet target, string suffix, HashSet nodes) + { + foreach (var node in nodes) + { + if (GetSimpleName(node).EndsWith(suffix, StringComparison.Ordinal)) + { + target.Add(node); + } + } + } + + private static Dictionary BuildSimpleNameLookup(HashSet nodes) + { + var lookup = new Dictionary>(StringComparer.Ordinal); + foreach (var node in nodes) + { + var simpleName = StripGenericArity(GetSimpleName(node)); + if (!lookup.TryGetValue(simpleName, out var matchingNodes)) + { + matchingNodes = []; + lookup.Add(simpleName, matchingNodes); + } + + matchingNodes.Add(node); + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (simpleName, matchingNodes) in lookup) + { + result.Add(simpleName, [.. matchingNodes]); + } + + return result; + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + return GetReachableTypes(roots, references, expandableNodes: null); + } + + private static HashSet GetReachableTypes( + HashSet roots, + IReadOnlyDictionary> references, + HashSet? expandableNodes) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (expandableNodes != null && !expandableNodes.Contains(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static bool HasPublicApiPredecessor( + string name, + IReadOnlyDictionary> references, + HashSet publicizeReachable, + HashSet generatedImplementationInternalDeclarations) + { + foreach (var (owner, children) in references) + { + if (!publicizeReachable.Contains(owner) || + string.Equals(owner, name, StringComparison.Ordinal) || + generatedImplementationInternalDeclarations.Contains(owner) || + !children.Contains(name)) + { + continue; + } + + return true; + } + + return false; + } + + private static void AddSignatureReferences( + HashSet references, + MethodSignatureBase signature, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeAttributes = true, + bool includeAttributeArguments = true) + { + AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddAttributes( + HashSet references, + IReadOnlyList attributes, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeArguments) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); + if (!includeArguments) + { + continue; + } + + foreach (var argument in attribute.Arguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + + foreach (var (_, argument) in attribute.PositionalArguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + } + } + + private static bool IsAttributeNamed(AttributeStatement attribute, string name) + => string.Equals(attribute.Type.Name, name, StringComparison.Ordinal) || + string.Equals(attribute.Type.Name, $"{name}Attribute", StringComparison.Ordinal); + + private static void AddAttributeArgumentReference( + HashSet references, + ValueExpression argument, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType) + { + if (argument is TypeOfExpression typeOf) + { + AddTypeReference(references, typeOf.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddTypeReference( + HashSet references, + CSharpType? type, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType = null) + { + if (type == null) + { + return; + } + + if (type.IsArray) + { + AddTypeReference(references, type.ElementType, nodes, serializationProviderNamesByType); + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) + { + foreach (var serializationProviderName in serializationProviderNames) + { + references.Add(serializationProviderName); + } + } + } + + AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes, serializationProviderNamesByType); + } + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string? GetNamespaceName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? null : fullyQualifiedName.Substring(0, lastDot); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs new file mode 100644 index 00000000000..eafe1d9d546 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed record ProviderReferenceMapResult( + HashSet InternalizeCandidates, + HashSet PublicizeCandidates, + HashSet RemoveCandidates) + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs index 49fe9723973..eb07aa4519f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs @@ -45,7 +45,7 @@ private bool IsPublicContext(TypeProvider provider) private void WriteType(CodeWriter writer) { - if (IsPublicContext(_provider)) + if (_provider.PreserveTypeXmlDocs || IsPublicContext(_provider)) { writer.WriteXmlDocsNoScope(_provider.XmlDocs); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs index ed3a45d54e8..26cc96af377 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs @@ -22,6 +22,7 @@ internal sealed class NamedTypeSymbolProvider : TypeProvider { private INamedTypeSymbol _namedTypeSymbol; private readonly Compilation _compilation; + private string? _metadataName; private TypeProvider? _baseTypeProvider; public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation compilation) @@ -30,6 +31,23 @@ public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation com _compilation = compilation; } + internal string MetadataName + { + get + { + if (_metadataName != null) + { + return _metadataName; + } + + var ns = _namedTypeSymbol.ContainingNamespace.GetFullyQualifiedNameFromDisplayString(); + _metadataName = string.IsNullOrEmpty(ns) ? _namedTypeSymbol.Name : $"{ns}.{_namedTypeSymbol.Name}"; + return _metadataName; + } + } + + internal string MetadataSimpleName => _namedTypeSymbol.Name; + private protected sealed override NamedTypeSymbolProvider? BuildCustomCodeView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => null; private protected sealed override TypeProvider? BuildLastContractView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => null; @@ -321,6 +339,165 @@ [.. methodSymbol.Parameters.Select(p => ConvertToParameterProvider(methodSymbol, return [.. methods]; } + protected internal override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new HashSet(); + foreach (var syntaxReference in _namedTypeSymbol.DeclaringSyntaxReferences) + { + AddBodyDependencyTypes(syntaxReference.GetSyntax(), dependencies); + } + + return [.. dependencies]; + } + + protected internal override IReadOnlyList BuildSignatureDependencyTypes() + { + var dependencies = new HashSet(); + foreach (var syntaxReference in _namedTypeSymbol.DeclaringSyntaxReferences) + { + if (syntaxReference.GetSyntax() is not TypeDeclarationSyntax typeDeclaration || + !IsPublic(typeDeclaration.Modifiers)) + { + continue; + } + + AddSyntaxTypeReferences(typeDeclaration.BaseList, dependencies); + foreach (var member in typeDeclaration.Members) + { + if (IsPublicApiMember(member)) + { + AddPublicSignatureDependencyTypes(member, dependencies); + } + } + } + + return [.. dependencies]; + } + + private void AddBodyDependencyTypes(SyntaxNode syntax, HashSet dependencies) + { + AddSyntaxTypeReferences(syntax, dependencies); + + foreach (var invocation in syntax.DescendantNodes().OfType()) + { + if (GetInvocationName(invocation) == "SetDelimited") + { + dependencies.Add(CreateUnresolvedDependencyType("SetDelimited")); + } + } + } + + private static void AddPublicSignatureDependencyTypes(MemberDeclarationSyntax member, HashSet dependencies) + { + switch (member) + { + case MethodDeclarationSyntax method: + AddSyntaxTypeReferences(method.ReturnType, dependencies); + AddSyntaxTypeReferences(method.ParameterList, dependencies); + AddSyntaxTypeReferences(method.ConstraintClauses, dependencies); + break; + case ConstructorDeclarationSyntax constructor: + AddSyntaxTypeReferences(constructor.ParameterList, dependencies); + break; + case ConversionOperatorDeclarationSyntax conversion: + AddSyntaxTypeReferences(conversion.Type, dependencies); + AddSyntaxTypeReferences(conversion.ParameterList, dependencies); + break; + case OperatorDeclarationSyntax @operator: + AddSyntaxTypeReferences(@operator.ReturnType, dependencies); + AddSyntaxTypeReferences(@operator.ParameterList, dependencies); + break; + case PropertyDeclarationSyntax property: + AddSyntaxTypeReferences(property.Type, dependencies); + break; + case IndexerDeclarationSyntax indexer: + AddSyntaxTypeReferences(indexer.Type, dependencies); + AddSyntaxTypeReferences(indexer.ParameterList, dependencies); + break; + case FieldDeclarationSyntax field: + AddSyntaxTypeReferences(field.Declaration.Type, dependencies); + break; + case EventFieldDeclarationSyntax eventField: + AddSyntaxTypeReferences(eventField.Declaration.Type, dependencies); + break; + case EventDeclarationSyntax @event: + AddSyntaxTypeReferences(@event.Type, dependencies); + break; + case DelegateDeclarationSyntax @delegate: + AddSyntaxTypeReferences(@delegate.ReturnType, dependencies); + AddSyntaxTypeReferences(@delegate.ParameterList, dependencies); + AddSyntaxTypeReferences(@delegate.ConstraintClauses, dependencies); + break; + case BaseTypeDeclarationSyntax type: + AddSyntaxTypeReferences(type.BaseList, dependencies); + break; + } + } + + private static void AddSyntaxTypeReferences(SyntaxNode? node, HashSet dependencies) + { + if (node == null) + { + return; + } + + foreach (var name in node.DescendantNodesAndSelf().OfType()) + { + dependencies.Add(CreateUnresolvedDependencyType(name.Identifier.ValueText)); + } + + foreach (var name in node.DescendantNodesAndSelf().OfType()) + { + dependencies.Add(CreateUnresolvedDependencyType(name.Identifier.ValueText)); + } + } + + private static void AddSyntaxTypeReferences(IEnumerable nodes, HashSet dependencies) + { + foreach (var node in nodes) + { + AddSyntaxTypeReferences(node, dependencies); + } + } + + private static bool IsPublicApiMember(MemberDeclarationSyntax member) + => member switch + { + EventDeclarationSyntax @event => IsPublic(@event.Modifiers), + EventFieldDeclarationSyntax @event => IsPublic(@event.Modifiers), + BaseFieldDeclarationSyntax field => IsPublic(field.Modifiers), + BaseMethodDeclarationSyntax method => IsPublic(method.Modifiers), + BasePropertyDeclarationSyntax property => IsPublic(property.Modifiers), + DelegateDeclarationSyntax @delegate => IsPublic(@delegate.Modifiers), + BaseTypeDeclarationSyntax type => IsPublic(type.Modifiers), + _ => false + }; + + private static bool IsPublic(SyntaxTokenList modifiers) + => modifiers.Any(static modifier => + modifier.IsKind(SyntaxKind.PublicKeyword) || + modifier.IsKind(SyntaxKind.ProtectedKeyword)); + + private static CSharpType CreateUnresolvedDependencyType(string name) + => new( + name, + string.Empty, + isValueType: false, + isNullable: false, + declaringType: null, + args: [], + isPublic: false, + isStruct: false); + + private static string? GetInvocationName(InvocationExpressionSyntax invocation) + => invocation.Expression switch + { + IdentifierNameSyntax identifier => identifier.Identifier.ValueText, + MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.ValueText, + GenericNameSyntax genericName => genericName.Identifier.ValueText, + _ => null + }; + private static bool IsPartialMethodDeclaration(IMethodSymbol methodSymbol) { foreach (var syntaxReference in methodSymbol.DeclaringSyntaxReferences) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 3d71670d5a9..bbaf3e54d24 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -143,6 +143,13 @@ public XmlDocProvider XmlDocs private set => _xmlDocs = value; } + internal bool PreserveTypeXmlDocs { get; private set; } + + internal void PreserveXmlDocs() + { + PreserveTypeXmlDocs = true; + } + public string? Deprecated { get => _deprecated; @@ -292,6 +299,22 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SerializationProviders => _serializationProviders ??= BuildSerializationProviders(); + private IReadOnlyList? _helperDependencyTypes; + internal IReadOnlyList HelperDependencyTypes => _helperDependencyTypes ??= BuildHelperDependencyTypes(); + protected internal virtual IReadOnlyList BuildHelperDependencyTypes() => []; + + private IReadOnlyList? _bodyDependencyTypes; + internal IReadOnlyList BodyDependencyTypes => _bodyDependencyTypes ??= BuildBodyDependencyTypes(); + protected internal virtual IReadOnlyList BuildBodyDependencyTypes() => []; + + private IReadOnlyList? _signatureDependencyTypes; + internal IReadOnlyList SignatureDependencyTypes => _signatureDependencyTypes ??= BuildSignatureDependencyTypes(); + protected internal virtual IReadOnlyList BuildSignatureDependencyTypes() => []; + + protected internal virtual bool IsClientProvider => false; + + protected internal virtual bool IncludeGeneratedBodyReferences => false; + private IReadOnlyList? _attributes; public IReadOnlyList Attributes @@ -538,6 +561,7 @@ public virtual void Reset() _serializationProviders = null; _nestedTypes = null; _xmlDocs = null; + PreserveTypeXmlDocs = false; _declarationModifiers = null; _relativeFilePath = null; _customCodeView = new(() => BuildCustomCodeView()); @@ -741,75 +765,10 @@ internal void ProcessTypeForBackCompatibility() { _enumValues = updatedEnumValues; } - - // Back-compatibility processing intentionally runs after the library visitor pass so - // that the contract comparison uses the final, post-visitor member signatures (otherwise - // we could incorrectly decide whether a back-compat member is needed). As a result, any - // members synthesized above (e.g. back-compat overloads) have not been visited yet. Run - // only those newly-added members through the visitors now so visitor transforms apply to - // them as well, without re-visiting members that were already visited during the main pass. - if (newMethods != null) - { - newMethods = VisitNewMembers(newMethods, Methods, static (member, visitor) => member.Accept(visitor)); - } - if (newConstructors != null) - { - newConstructors = VisitNewMembers(newConstructors, Constructors, static (member, visitor) => visitor.VisitConstructor(member)); - } - if (newFields != null) - { - newFields = VisitNewMembers(newFields, Fields, static (member, visitor) => visitor.VisitField(member)); - } - Update(fields: newFields, methods: newMethods, constructors: newConstructors); } } - // Runs newly-added back-compatibility members through every registered visitor while leaving - // members that were already visited during the main visitor pass untouched. Membership in the - // already-visited set is determined by reference identity against the pre-Update collection. - private static IReadOnlyList VisitNewMembers( - IEnumerable allMembers, - IReadOnlyList alreadyVisited, - Func visit) - where T : class - { - var visitors = CodeModelGenerator.Instance.Visitors; - var materialized = allMembers as IReadOnlyList ?? [.. allMembers]; - if (visitors.Count == 0) - { - return materialized; - } - - var alreadyVisitedSet = new HashSet(alreadyVisited, ReferenceEqualityComparer.Instance); - var result = new List(materialized.Count); - foreach (var member in materialized) - { - if (alreadyVisitedSet.Contains(member)) - { - result.Add(member); - continue; - } - - T? visited = member; - foreach (var visitor in visitors) - { - visited = visit(visited, visitor); - if (visited == null) - { - break; - } - } - - if (visited != null) - { - result.Add(visited); - } - } - - return result; - } - protected internal virtual IReadOnlyList? BuildEnumValuesForBackCompatibility(IReadOnlyList originalEnumValues) => null; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs index a329166ee4b..4792e781338 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs @@ -22,6 +22,7 @@ public class SourceInputModel public ApiCompatBaseline ApiCompatBaseline { get; } private readonly Lazy> _nameMap; + private readonly Lazy> _customizationTypeProviders; public SourceInputModel(Compilation? customization, Compilation? lastContract) : this(customization, lastContract, ApiCompatBaseline.Empty) @@ -35,6 +36,7 @@ public SourceInputModel(Compilation? customization, Compilation? lastContract, A ApiCompatBaseline = apiCompatBaseline ?? ApiCompatBaseline.Empty; _nameMap = new(PopulateNameMap); + _customizationTypeProviders = new(PopulateCustomizationTypeProviders); } private IReadOnlyDictionary PopulateNameMap() @@ -70,6 +72,30 @@ private IReadOnlyDictionary PopulateNameMap() return FindTypeInCompilation(LastContract, ns, name, true, declaringTypeName, includeInternal: false); } + private IReadOnlyList PopulateCustomizationTypeProviders() + { + var providers = new List(); + if (Customization == null) + { + return providers; + } + + foreach (IModuleSymbol module in Customization.Assembly.Modules) + { + foreach (var type in SourceInputHelper.GetSymbols(module.GlobalNamespace)) + { + if (type is INamedTypeSymbol namedTypeSymbol) + { + providers.Add(new NamedTypeSymbolProvider(namedTypeSymbol, Customization)); + } + } + } + + return providers; + } + + internal IReadOnlyList GetCustomizationTypeProviders() => _customizationTypeProviders.Value; + private TypeProvider? FindTypeInCompilation( Compilation? compilation, string ns, diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs index cd70e60ea9a..f53cdcc4d55 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs @@ -21,6 +21,9 @@ public class TypeFactory private ChangeTrackingDictionaryDefinition ChangeTrackingDictionaryProvider { get; } = new(); + private OptionalDefinition? _optionalProvider; + private OptionalDefinition OptionalProvider => _optionalProvider ??= new(); + private Dictionary InputTypeToModelProvider { get; } = []; public IDictionary CSharpTypeMap { get; } = new Dictionary(CSharpType.IgnoreNullableComparer); @@ -200,11 +203,6 @@ protected internal TypeFactory() if (modelProvider != null) { - if (model.Access == "public") - { - CodeModelGenerator.Instance.AddTypeToKeep(modelProvider); - } - CSharpTypeMap[modelProvider.Type] = modelProvider; TypeProvidersByName[modelProvider.Type.Name] = modelProvider; } @@ -500,6 +498,11 @@ inputProperty.Type is InputArrayType && /// public virtual CSharpType DictionaryInitializationType => ChangeTrackingDictionaryProvider.Type; + /// + /// The type used to represent optional values in generated helper code. + /// + public virtual CSharpType OptionalType => OptionalProvider.Type; + /// /// Returns the serialization type providers for the given model type provider. /// diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs index f716aad5ab3..240dd77a792 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs @@ -451,7 +451,7 @@ private class TestFilterVisitor : LibraryVisitor return method; } - protected internal override ConstructorProvider? VisitConstructor(ConstructorProvider constructor) + protected override ConstructorProvider? VisitConstructor(ConstructorProvider constructor) { if (constructor.Signature.Parameters.Count > 0) { @@ -469,7 +469,7 @@ private class TestFilterVisitor : LibraryVisitor return property; } - protected internal override FieldProvider? VisitField(FieldProvider field) + protected override FieldProvider? VisitField(FieldProvider field) { if (field.Name == "TestField") { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs index 28981148a4d..8e8c10dca8d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs @@ -60,7 +60,42 @@ public async Task RemovesInvalidUsings() CollectionAssert.Contains(usings, "System"); } + [Test] + public async Task RemovesInvalidUsingsKeepsFileHeader() + { + MockHelpers.LoadMockGenerator(); + var workspace = new AdhocWorkspace(); + var projectInfo = ProjectInfo.Create( + ProjectId.CreateNewId(), + VersionStamp.Create(), + name: "TestProj", + assemblyName: "TestProj", + language: LanguageNames.CSharp) + .WithMetadataReferences(new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location) + }); + + var project = workspace.AddProject(projectInfo); + var folder = Helpers.GetAssetFileOrDirectoryPath(false); + project = AddGeneratedDocument( + project, + "RootClient.cs", + "src", + "Generated", + "RootClient.cs", + File.ReadAllText(Path.Join(folder, "RootClient.cs"))); + var postProcessor = new TestPostProcessor("RootClient.cs"); + + var resultProject = await postProcessor.RemoveAsync(project); + var doc = resultProject.Documents.Single(d => d.Name == "RootClient.cs"); + var text = (await doc.GetTextAsync()).ToString(); + StringAssert.StartsWith("// Copyright (c) Microsoft Corporation. All rights reserved.", text); + StringAssert.Contains("// ", text); + StringAssert.Contains("#nullable disable", text); + StringAssert.DoesNotContain("using Missing.Namespace;", text); + } [Test] public async Task DoesNotRemoveValidUsings() { @@ -289,11 +324,14 @@ public async Task DoesNotRemoveValidAttributes() Assert.AreEqual(Helpers.GetExpectedFromFile().TrimEnd(), output, "The output should match the expected content."); } + private static Project AddGeneratedDocument(Project project, string name, string folder1, string folder2, string fileName, string text) + => project.AddDocument(name, text, folders: [folder1, folder2], filePath: Path.Join(folder1, folder2, fileName)).Project; + private class TestPostProcessor : PostProcessor { private readonly string _rootFile; - public TestPostProcessor(string rootFile, IEnumerable? nonRootTypes = null) : base([], additionalNonRootTypeNames: nonRootTypes) + public TestPostProcessor(string rootFile, IEnumerable? additionalRootTypeNames = null, IEnumerable? nonRootTypes = null, string? modelFactoryFullName = null) : base((additionalRootTypeNames ?? []).ToHashSet(), modelFactoryFullName: modelFactoryFullName, additionalNonRootTypeNames: nonRootTypes) { _rootFile = rootFile; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs new file mode 100644 index 00000000000..572d6f590ac --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using Missing.Namespace; + +namespace Sample +{ + public partial class RootClient + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs index 1bdf4020167..c35f1968d76 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -177,6 +177,32 @@ public async Task OmitsModelFactoryMethodIfParamTypeInternal() Assert.IsNull(modelFactory); } + // This test validates that a derived model customized to be internal does not get a + // public model factory method just because its base model remains public. + [Test] + public async Task OmitsModelFactoryMethodIfDerivedModelTypeInternal() + { + var baseModel = InputFactory.Model( + "baseModel", + properties: [InputFactory.Property("BaseProp", InputPrimitiveType.String)]); + var derivedModel = InputFactory.Model( + "derivedModel", + properties: [InputFactory.Property("DerivedProp", InputPrimitiveType.String)], + baseModel: baseModel); + + var mockGenerator = await MockHelpers.LoadMockGeneratorAsync( + inputModelTypes: [baseModel, derivedModel], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + var modelFactory = mockGenerator.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + CollectionAssert.Contains(modelFactory!.Methods.Select(m => m.Signature.Name), "BaseModel"); + CollectionAssert.DoesNotContain(modelFactory.Methods.Select(m => m.Signature.Name), "DerivedModel"); + } + [TestCase(true)] [TestCase(false)] public async Task CanCustomizeModelFullConstructor(bool extraParameters) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs index 934239ba78f..45f9b0385b4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs @@ -975,61 +975,6 @@ public async Task BackCompatibility_BackCompatMethodAlreadyCustom() Assert.AreEqual(4, publicModel1Methods[0].Signature.Parameters.Count); } - // Back-compat members are synthesized in ProcessTypeForBackCompatibility, which runs after the - // main library visitor pass. This test ensures those newly-added members are still run through - // the registered visitors (only the new members, not the already-visited existing ones). - [Test] - public async Task BackCompatibility_BackCompatMethodIsVisited() - { - _instance = (await MockHelpers.LoadMockGeneratorAsync( - inputNamespaceName: "Sample.Namespace", - inputModelTypes: ModelList, - lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(method: "BackCompatibility_NewModelPropertyAdded"))).Object; - - var recordingVisitor = new RecordingMethodVisitor(); - _instance.AddVisitor(recordingVisitor); - - var modelFactory = _instance!.OutputLibrary.ModelFactory.Value; - modelFactory.ProcessTypeForBackCompatibility(); - - var backCompatMethod = modelFactory.Methods - .FirstOrDefault(m => m.Signature.Name == "PublicModel1" && m.Signature.Parameters.All(p => p.Name != "dictProp")); - Assert.IsNotNull(backCompatMethod, "Expected a back-compat overload to be synthesized."); - - // The synthesized back-compat method must have been visited. - Assert.IsTrue( - recordingVisitor.VisitedMethods.Contains(backCompatMethod!), - "The back-compat method was not visited by the library visitor."); - - // Existing methods that were already part of the contract are not re-visited by the visitor - // added after the main pass (they would have been visited during the main pass in a real run). - var currentOverloadMethod = modelFactory.Methods - .FirstOrDefault(m => m.Signature.Name == "PublicModel1" && m.Signature.Parameters.Any(p => p.Name == "dictProp")); - Assert.IsNotNull(currentOverloadMethod); - Assert.IsFalse(recordingVisitor.VisitedMethods.Contains(currentOverloadMethod!)); - } - - // Verifies that a visitor can mutate (rename) a synthesized back-compat method and the change is - // reflected in the final generated methods. - [Test] - public async Task BackCompatibility_BackCompatMethodCanBeMutatedByVisitor() - { - _instance = (await MockHelpers.LoadMockGeneratorAsync( - inputNamespaceName: "Sample.Namespace", - inputModelTypes: ModelList, - lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(method: "BackCompatibility_NewModelPropertyAdded"))).Object; - - _instance.AddVisitor(new BackCompatMethodRenamingVisitor()); - - var modelFactory = _instance!.OutputLibrary.ModelFactory.Value; - modelFactory.ProcessTypeForBackCompatibility(); - - // The visitor renames any method carrying the EditorBrowsableNever attribute (the back-compat - // overload) so the mutation must be observable on the final method collection. - var renamed = modelFactory.Methods.FirstOrDefault(m => m.Signature.Name == "PublicModel1Renamed"); - Assert.IsNotNull(renamed, "The visitor's rename of the back-compat method was not applied."); - } - private static InputModelType[] GetTestModels() { InputType additionalPropertiesUnknown = InputPrimitiveType.Any; @@ -1057,29 +1002,5 @@ private static InputModelType[] GetTestModels() InputFactory.Model("ModelWithUnknownAdditionalProperties", properties: properties, additionalProperties: additionalPropertiesUnknown), ]; } - - private class RecordingMethodVisitor : LibraryVisitor - { - public List VisitedMethods { get; } = []; - - protected internal override MethodProvider? VisitMethod(MethodProvider method) - { - VisitedMethods.Add(method); - return method; - } - } - - private class BackCompatMethodRenamingVisitor : LibraryVisitor - { - protected internal override MethodProvider? VisitMethod(MethodProvider method) - { - if (method.Signature.Name == "PublicModel1" - && method.Signature.Attributes.Any(a => a.ToDisplayString().Contains("EditorBrowsable"))) - { - method.Signature.Update(name: "PublicModel1Renamed"); - } - return method; - } - } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs new file mode 100644 index 00000000000..bdb2034f5f0 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Sample.Models +{ + internal partial class DerivedModel + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs index 7f3ea5cd1ff..d96e98a9d8e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs @@ -405,7 +405,7 @@ private class ClientTypeProvider : TypeProvider public MethodProvider[] MethodProviders { get; set; } = []; public ConstructorProvider[] ConstructorProviders { get; set; } = []; - protected override string BuildRelativeFilePath() => "."; + protected override string BuildRelativeFilePath() => $"{Name}.cs"; protected override string BuildName() => "MockInputClient"; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs index 9f9945a2360..a30bd15aa29 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs @@ -1524,7 +1524,7 @@ await MockHelpers.LoadMockGeneratorAsync( } [Test] - public void PublicModelsAreIncludedInAdditionalRootTypes() + public void PublicModelsAreNotIncludedInAdditionalRootTypes() { var inputModel = InputFactory.Model( "MockInputModel", @@ -1537,7 +1537,7 @@ public void PublicModelsAreIncludedInAdditionalRootTypes() Assert.IsNotNull(modelProvider); var rootTypes = CodeModelGenerator.Instance.AdditionalRootTypes; - Assert.IsTrue(rootTypes.Contains("Sample.Models.MockInputModel")); + Assert.IsFalse(rootTypes.Contains("Sample.Models.MockInputModel")); } [Test] diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs index 8918ac1946c..4440dc1b6a0 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs @@ -8,7 +8,6 @@ namespace Documentation { public static partial class DocumentationModelFactory { - public static BulletPointsModel BulletPointsModel(BulletPointsEnum prop = default) => throw null; } } diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs index 06d44d34bc1..c17dda5eaec 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs @@ -3,7 +3,6 @@ #nullable disable using Parameters.Basic._ExplicitBody; -using Parameters.Basic._ImplicitBody; namespace Parameters.Basic { diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs index 775c933bc6b..4e494c2f2ec 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs @@ -2,8 +2,6 @@ #nullable disable -using System.Collections.Generic; -using Parameters.Spread._Alias; using Parameters.Spread._Model; namespace Parameters.Spread diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs index 6d036b01c86..e5c87b994a6 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs @@ -2,7 +2,6 @@ #nullable disable -using System; using System.ClientModel; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs index 0b66e798a05..1b7b8eefa6f 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs @@ -2,18 +2,12 @@ #nullable disable -using System; -using System.Collections.Generic; -using Payload.Pageable._PageSize; -using Payload.Pageable._ServerDrivenPagination; using Payload.Pageable._ServerDrivenPagination.AlternateInitialVerb; -using Payload.Pageable._ServerDrivenPagination.ContinuationToken; namespace Payload.Pageable { public static partial class PayloadPageableModelFactory { - public static Pet Pet(string id = default, string name = default) => throw null; public static XmlPet XmlPet(string id = default, string name = default) => throw null; diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs index 1abb2114c9a..43aca8db259 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs @@ -2,10 +2,8 @@ #nullable disable -using System.Collections.Generic; using SpecialWords._ModelProperties; using SpecialWords._Models; -using SpecialWords._ReservedOperationBodyParams; namespace SpecialWords { From d3e60effb2740365902b65622d702c32b250219a Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 03:07:19 +0000 Subject: [PATCH 02/19] refactor(http-client-csharp): replace Roslyn reference map analysis Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/ClientProvider.cs | 77 + .../Providers/ClientUriBuilderDefinition.cs | 1 + .../Providers/CollectionResultDefinition.cs | 16 + .../MrwSerializationTypeDefinition.Xml.cs | 6 +- .../MrwSerializationTypeDefinition.cs | 71 +- ...ultipartFormDataSerializationDefinition.cs | 6 + .../src/Providers/RestClientProvider.cs | 150 +- .../SerializationFormatDefinition.cs | 2 + .../src/Snippets/HttpRequestApiSnippets.cs | 4 +- ...ClientBodyDependencyPostProcessingTests.cs | 734 ++++++ .../SystemObjectModelSerializationTests.cs | 34 - .../RestClientProviderTests.cs | 32 - .../test/TestHelpers/MockHelpers.cs | 17 +- .../test/TypeSpecInputConverterTests.cs | 36 - .../src/CSharpGen.cs | 41 +- .../src/LibraryVisitor.cs | 4 +- .../PostProcessing/GeneratedCodeWorkspace.cs | 12 +- .../src/PostProcessing/PostProcessor.cs | 177 +- .../ProviderReferenceMapAnalyzer.cs | 2218 +++++++++++++++++ .../ProviderReferenceMapResult.cs | 14 + .../src/Primitives/TypeProviderWriter.cs | 2 +- .../src/Providers/NamedTypeSymbolProvider.cs | 177 ++ .../src/Providers/TypeProvider.cs | 89 +- .../src/SourceInput/SourceInputModel.cs | 26 + .../src/TypeFactory.cs | 18 +- .../src/Utilities/TypeSymbolExtensions.cs | 2 +- .../test/OutputLibraryVisitorTests.cs | 4 +- .../test/PostProcessing/PostProcessorTests.cs | 40 +- .../RootClient.cs | 15 + .../ModelFactoriesCustomizationTests.cs | 26 + .../ModelFactoryProviderTests.cs | 79 - .../DerivedModel.cs | 9 + .../ClientCustomizationTests.cs | 2 +- .../ModelProviders/ModelProviderTests.cs | 4 +- .../GenericContainer.cs | 6 + .../Utilities/TypeSymbolExtensionsTests.cs | 13 + .../Generated/DocumentationModelFactory.cs | 1 - .../Generated/ParametersBasicModelFactory.cs | 1 - .../Generated/ParametersSpreadModelFactory.cs | 2 - .../Generated/PayloadMultiPartModelFactory.cs | 1 - .../Generated/PayloadPageableModelFactory.cs | 6 - .../src/Generated/SpecialWordsModelFactory.cs | 2 - 42 files changed, 3651 insertions(+), 526 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index f8ee8744e32..2c7186b2f25 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -43,6 +43,7 @@ private record ApiVersionFields(FieldProvider Field, PropertyProvider? Correspon private const string ClientSuffix = "Client"; private readonly FormattableString _publicCtorDescription; private readonly InputClient _inputClient; + protected override bool IsReferenceMapRoot => true; internal InputClient InputClient => _inputClient; private readonly InputAuth? _inputAuth; private readonly ParameterProvider _endpointParameter; @@ -426,6 +427,82 @@ private IReadOnlyList GetClientParameters() protected override string BuildName() => _inputClient.IsExactName ? _inputClient.Name : _inputClient.Name.ToIdentifierName(); + protected override IReadOnlyList BuildHelperDependencyTypes() + { + foreach (var method in Methods.OfType()) + { + if (!method.IsMethodSuppressed() && method.BodyStatements != null) + { + return [new CancellationTokenExtensionsDefinition().Type, new ClientPipelineExtensionsDefinition().Type]; + } + } + + return []; + } + + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List(); + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements == null) + { + continue; + } + + if (method.CollectionDefinition != null) + { + dependencies.Add(method.CollectionDefinition.Type); + } + + if (method.ServiceMethod == null) + { + continue; + } + + AddInputTypeDependency(dependencies, method.ServiceMethod.Response.Type); + AddInputTypeDependency(dependencies, method.ServiceMethod.Exception?.Type); + foreach (var parameter in method.ServiceMethod.Parameters) + { + if (IsContentTypeParameter(parameter)) + { + continue; + } + + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var parameter in method.ServiceMethod.Operation.Parameters) + { + if (IsContentTypeParameter(parameter)) + { + continue; + } + + AddInputTypeDependency(dependencies, parameter.Type); + } + + // Operation responses are input metadata. The generated method signature and body + // dependencies above capture the response types that are actually used. + } + + return dependencies; + } + + private static bool IsContentTypeParameter(InputParameter parameter) => + parameter is InputHeaderParameter { IsContentType: true } || + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + + private static void AddInputTypeDependency(List dependencies, InputType? inputType) + { + var type = inputType == null ? null : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputType); + if (type != null) + { + dependencies.Add(type); + } + } + protected override FieldProvider[] BuildFields() { List fields = [EndpointField]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs index 92e0cf23a67..4dc5c283cc7 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs @@ -28,6 +28,7 @@ internal sealed class ClientUriBuilderDefinition : TypeProvider private readonly FieldProvider _uriBuilderField; private readonly FieldProvider _pathAndQueryField; private readonly FieldProvider _pathLengthField; + protected override bool IncludeGeneratedBodyReferences => true; private PropertyProvider? _uriBuilderProperty; private PropertyProvider UriBuilderProperty => _uriBuilderProperty ??= new( diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index ae617957bf5..590eaf2b935 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -217,6 +217,22 @@ private bool HasPagingOperationNameCollision(string operationName) protected override TypeSignatureModifiers BuildDeclarationModifiers() => TypeSignatureModifiers.Internal | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class; + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType }; + if (ItemModelType != null) + { + dependencies.Add(ItemModelType); + } + + foreach (var field in RequestFields) + { + dependencies.Add(field.Type); + } + + return dependencies; + } + protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs index 5d76b8f44a1..2b50fff372e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs @@ -67,7 +67,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Internal | MethodSignatureModifiers.Virtual; - if (_shouldOverrideXmlMethods) + if (_shouldOverrideMethods) { modifiers = MethodSignatureModifiers.Internal | MethodSignatureModifiers.Override; } @@ -81,7 +81,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() { - var categorizedProperties = _shouldOverrideXmlMethods + var categorizedProperties = _shouldOverrideMethods ? CategorizedXmlProperties : AllCategorizedXmlProperties; var statements = new List @@ -90,7 +90,7 @@ private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() MethodBodyStatement.EmptyLine }; - if (_shouldOverrideXmlMethods) + if (_shouldOverrideMethods) { statements.Add(Base.Invoke(XmlModelWriteCoreMethodName, _xmlWriterParameter, _serializationOptionsParameter).Terminate()); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 8202a2af405..97532886c40 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -53,16 +53,10 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly ScopedApi _mrwOptionsParameterSnippet; private readonly ScopedApi _jsonElementParameterSnippet; private readonly ScopedApi _isNotEqualToWireConditionSnippet; - // These interface types depend on _model.Type. Build them lazily so we do not cache a - // CSharpType before delayed base model resolution has updated the model's inheritance. - private CSharpType? _jsonModelTInterfaceValue; - private CSharpType _jsonModelTInterface => _jsonModelTInterfaceValue ??= new CSharpType(typeof(IJsonModel<>), SerializationInterfaceType.Type); - private CSharpType? _jsonModelObjectInterface; - private CSharpType? JsonModelObjectInterface => _isStruct ? _jsonModelObjectInterface ??= (CSharpType)typeof(IJsonModel) : null; - private CSharpType? _persistableModelTInterfaceValue; - private CSharpType _persistableModelTInterface => _persistableModelTInterfaceValue ??= new CSharpType(typeof(IPersistableModel<>), SerializationInterfaceType.Type); - private CSharpType? _persistableModelObjectInterface; - private CSharpType? PersistableModelObjectInterface => _isStruct ? _persistableModelObjectInterface ??= (CSharpType)typeof(IPersistableModel) : null; + private readonly CSharpType _jsonModelTInterface; + private readonly CSharpType? _jsonModelObjectInterface; + private readonly CSharpType _persistableModelTInterface; + private readonly CSharpType? _persistableModelObjectInterface; private readonly ModelProvider _model; private readonly InputModelType _inputModel; private readonly FieldProvider? _rawDataField; @@ -73,20 +67,10 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly bool _supportsXml; private ConstructorProvider? _serializationConstructor; // Flag to determine if the model should override the serialization methods - private bool? _shouldOverrideMethods; - private bool ShouldOverrideMethods => _shouldOverrideMethods ??= _model.BaseModelProvider != null && !_isStruct; - private bool? _shouldSkipSerializationMethodOverrides; - private bool ShouldSkipSerializationMethodOverrides => _shouldSkipSerializationMethodOverrides ??= ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); - private readonly bool _shouldOverrideXmlMethods; + private readonly bool _shouldOverrideMethods; + private readonly bool _shouldSkipDerivedSerializationMethodOverrides; private readonly Lazy _additionalProperties; - // Unknown discriminator models use their base model as the serialization interface type. - // This can also touch model.Type, so defer it until serialization method/interface emission. - private TypeProvider SerializationInterfaceType => _serializationInterfaceType ??= _inputModel.IsUnknownDiscriminatorModel - ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel!)! - : _model; - private TypeProvider? _serializationInterfaceType; - private CSharpType RootType => _rootType ??= GetRootModelType(); private CSharpType? _rootType; @@ -100,10 +84,17 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m _isStruct = _model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Struct); _supportsXml = inputModel.Usage.HasFlag(InputModelTypeUsage.Xml); _supportsJson = inputModel.Usage.HasFlag(InputModelTypeUsage.Json) || !_supportsXml; - _shouldOverrideXmlMethods = _model.BaseModelProvider != null && !_isStruct; + // Initialize the serialization interfaces + var interfaceType = inputModel.IsUnknownDiscriminatorModel ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel.BaseModel!)! : _model; + _jsonModelTInterface = new CSharpType(typeof(IJsonModel<>), interfaceType.Type); + _jsonModelObjectInterface = _isStruct ? (CSharpType)typeof(IJsonModel) : null; + _persistableModelTInterface = new CSharpType(typeof(IPersistableModel<>), interfaceType.Type); + _persistableModelObjectInterface = _isStruct ? (CSharpType)typeof(IPersistableModel) : null; _rawDataField = _model.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); _additionalBinaryDataProperty = new(GetAdditionalBinaryDataPropertiesProp); _additionalProperties = new(() => [.. _model.Properties.Where(p => p.IsAdditionalProperties)]); + _shouldOverrideMethods = _model.BaseModelProvider != null && !_isStruct; + _shouldSkipDerivedSerializationMethodOverrides = ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); _utf8JsonWriterSnippet = _utf8JsonWriterParameter.As(); _mrwOptionsParameterSnippet = _serializationOptionsParameter.As(); _jsonElementParameterSnippet = _jsonElementDeserializationParam.As(); @@ -126,6 +117,10 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m protected override CSharpType? BuildBaseType() => _model.BaseType; + protected override IReadOnlyList BuildHelperDependencyTypes() => _rawDataField != null || _additionalProperties.Value.Length > 0 + ? [ScmCodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() { if (_model.CanonicalView.Properties.Any(p => ScmModelProvider.IsFileBinaryContentType(p.Type))) @@ -438,19 +433,17 @@ protected override CSharpType[] BuildImplements() if (_supportsJson) { interfaces.Add(_jsonModelTInterface); - var jsonModelObjectInterface = JsonModelObjectInterface; - if (jsonModelObjectInterface != null) + if (_jsonModelObjectInterface != null) { - interfaces.Add(jsonModelObjectInterface); + interfaces.Add(_jsonModelObjectInterface); } } else if (_supportsXml) { interfaces.Add(_persistableModelTInterface); - var persistableModelObjectInterface = PersistableModelObjectInterface; - if (persistableModelObjectInterface != null) + if (_persistableModelObjectInterface != null) { - interfaces.Add(persistableModelObjectInterface); + interfaces.Add(_persistableModelObjectInterface); } } @@ -480,7 +473,7 @@ internal MethodProvider BuildJsonModelWriteMethodObjectDeclaration() var castToT = This.CastTo(_jsonModelTInterface); return new MethodProvider ( - new MethodSignature(nameof(IJsonModel.Write), null, MethodSignatureModifiers.None, null, null, [_utf8JsonWriterParameter, _serializationOptionsParameter], ExplicitInterface: JsonModelObjectInterface), + new MethodSignature(nameof(IJsonModel.Write), null, MethodSignatureModifiers.None, null, null, [_utf8JsonWriterParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelObjectInterface), castToT.Invoke(nameof(IJsonModel.Write), [_utf8JsonWriterParameter, _serializationOptionsParameter]), this ); @@ -495,7 +488,7 @@ internal MethodProvider BuildJsonModelCreateMethodObjectDeclaration() var castToT = This.CastTo(_jsonModelTInterface); return new MethodProvider ( - new MethodSignature(nameof(IJsonModel.Create), null, MethodSignatureModifiers.None, typeof(object), null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: JsonModelObjectInterface), + new MethodSignature(nameof(IJsonModel.Create), null, MethodSignatureModifiers.None, typeof(object), null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelObjectInterface), castToT.Invoke(nameof(IJsonModel.Create), [_utf8JsonReaderParameter.AsArgument(), _serializationOptionsParameter]), this ); @@ -511,7 +504,7 @@ internal MethodProvider BuildPersistableModelWriteMethodObjectDeclaration() var returnType = typeof(BinaryData); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.Write), null, MethodSignatureModifiers.None, returnType, null, [_serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.Write), null, MethodSignatureModifiers.None, returnType, null, [_serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.Write), [_serializationOptionsParameter]), this ); @@ -527,7 +520,7 @@ internal MethodProvider BuildPersistableModelCreateMethodObjectDeclaration() var returnType = typeof(object); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.Create), null, MethodSignatureModifiers.None, returnType, null, [_dataParameter, _serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.Create), null, MethodSignatureModifiers.None, returnType, null, [_dataParameter, _serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.Create), [_dataParameter, _serializationOptionsParameter]), this ); @@ -541,7 +534,7 @@ internal MethodProvider BuildJsonModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods) + if (_shouldOverrideMethods) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -563,7 +556,7 @@ internal MethodProvider BuildPersistableModelWriteCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -587,7 +580,7 @@ internal MethodProvider BuildPersistableModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -635,7 +628,7 @@ internal MethodProvider BuildJsonModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -807,7 +800,7 @@ internal MethodProvider BuildPersistableModelGetFormatFromOptionsObjectDeclarati // string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => ((IPersistableModel)this).GetFormatFromOptions(options); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.GetFormatFromOptions), null, MethodSignatureModifiers.None, typeof(string), null, [_serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.GetFormatFromOptions), null, MethodSignatureModifiers.None, typeof(string), null, [_serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.GetFormatFromOptions), [_serializationOptionsParameter]), this ); @@ -1066,7 +1059,7 @@ private MethodBodyStatement[] BuildPersistableModelCreateCoreMethodBody() private MethodBodyStatement CallBaseJsonModelWriteCore(bool isDynamicModelWithNonDynamicBase) { // base.() - bool callBaseWriteMethod = ShouldOverrideMethods + bool callBaseWriteMethod = _shouldOverrideMethods && (_jsonPatchProperty is null || !isDynamicModelWithNonDynamicBase); return callBaseWriteMethod ? Base.Invoke(JsonModelWriteCoreMethodName, [_utf8JsonWriterParameter, _serializationOptionsParameter]).Terminate() diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs index 88cb97b16e7..696d654dd9c 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs @@ -51,6 +51,12 @@ protected override string BuildRelativeFilePath() return Path.Combine("src", "Generated", "Models", $"{Name}.Serialization.Multipart.cs"); } + protected override IReadOnlyList BuildHelperDependencyTypes() => _model.Properties.Any( + prop => prop.WireInfo != null && !prop.WireInfo.IsRequired && + (prop.Type is { IsCollection: true, IsReadOnlyMemory: false } || prop.Type.IsDictionary)) + ? [ScmCodeModelGenerator.Instance.TypeFactory.OptionalType] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() => [new SuppressionStatement(null, Literal(ScmModelProvider.FileBinaryContentDiagnosticId), ScmModelProvider.ScmEvaluationTypeSuppressionJustification)]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index f00d78571d0..5ccee318677 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -78,6 +78,44 @@ protected override FieldProvider[] BuildFields() return [.. pipelineMessage20xClassifiersFields]; } + protected override IReadOnlyList BuildHelperDependencyTypes() + { + var dependencies = new List { new ClientUriBuilderDefinition().Type }; + foreach (var serviceMethod in _inputClient.Methods) + { + foreach (var parameter in serviceMethod.Operation.Parameters) + { + if (IsGeneratedContentTypeMethodParameter(parameter) || + parameter is not InputHeaderParameter and not InputQueryParameter) + { + continue; + } + + var type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(parameter.Type); + if (type?.IsDictionary == true) + { + AddDependency(dependencies, ScmCodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType); + } + else if (type?.IsCollection == true) + { + AddDependency(dependencies, ScmCodeModelGenerator.Instance.TypeFactory.ListInitializationType); + } + } + } + + return dependencies; + } + + private static void AddDependency(List dependencies, CSharpType dependency) + { + if (!dependencies.Any(existing => + existing.Name == dependency.Name && + existing.Namespace == dependency.Namespace)) + { + dependencies.Add(dependency); + } + } + protected override ScmMethodProvider[] BuildMethods() { List methods = new List(); @@ -549,18 +587,6 @@ private static MethodBodyStatement BuildAppendQueryStatement( { if (paramType?.IsCollection != true) { - // A model-typed query parameter marked with `explode` must be expanded into one query - // entry per property (RFC 6570 form explode, e.g. `?field=status&value=active`) rather - // than serialized via the object's ToString (which previously produced the type name). - if (inputQueryParameter.Explode && inputQueryParameter.Type is InputModelType inputModel) - { - var explodeStatement = BuildExplodeModelQueryStatement(uri, inputModel, valueExpression); - if (explodeStatement != null) - { - return explodeStatement; - } - } - var toStringExpression = GetQueryParameterStringExpression(paramType, valueExpression, serializationFormat); return uri.AppendQuery(Literal(inputQueryParameter.SerializedName), toStringExpression, true).Terminate(); } @@ -617,70 +643,6 @@ private static MethodBodyStatement BuildAppendQueryStatement( } } - /// - /// Builds the statements for a model-typed query parameter that uses form-style `explode`. - /// Each (simple) property of the model is emitted as its own query entry using the property's - /// wire name (RFC 6570 form explode, e.g. ?field=status&value=active). - /// Returns null when the model contains a property that is not a simple scalar/enum - /// (e.g. a nested object or a collection), in which case the caller falls back to the default - /// handling. Nested/complex expansion is tracked separately (see issue #11123). - /// - private static MethodBodyStatement? BuildExplodeModelQueryStatement( - ScopedApi uri, - InputModelType inputModel, - ValueExpression valueExpression) - { - var modelProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel); - if (modelProvider is null) - { - return null; - } - - var properties = modelProvider.CanonicalView.Properties; - if (properties.Count == 0) - { - return null; - } - - // Only expand when every property is a simple scalar or enum. Nested objects and - // collections are not defined by RFC 6570 form explode and require a separate design - // decision, so we fall back to the default handling for those. - foreach (var property in properties) - { - if (property.WireInfo is null || - property.Type.IsCollection || - (!property.Type.IsFrameworkType && !property.Type.IsEnum)) - { - return null; - } - } - - var statements = new List(); - foreach (var property in properties) - { - var propertyAccess = valueExpression.Property(property.Name); - var propertyType = property.Type; - - ValueExpression convertedValue = propertyType.IsEnum - ? propertyType.ToSerial(propertyAccess).ConvertToString() - : GetQueryParameterStringExpression(propertyType, propertyAccess, property.SerializationFormat); - - MethodBodyStatement appendStatement = - uri.AppendQuery(Literal(property.WireInfo!.SerializedName), convertedValue, true).Terminate(); - - if (!property.WireInfo.IsRequired || - propertyType.IsNullable || - (propertyType is { IsValueType: false, IsFrameworkType: true } && propertyType.FrameworkType != typeof(string))) - { - appendStatement = BuildQueryOrHeaderOrPathParameterNullCheck(propertyType, propertyAccess, appendStatement); - } - - statements.Add(appendStatement); - } - - return statements; - } - private static IfStatement BuildQueryOrHeaderOrPathParameterNullCheck( CSharpType? parameterType, ValueExpression valueExpression, @@ -919,7 +881,9 @@ private static void AppendLiteralSegment(ScopedApi uri, string literal, List paramMap, InputOperation operation, InputParameter inputParam, out CSharpType? type, out SerializationFormat? serializationFormat, out ValueExpression? valueExpression) { - type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); + type = IsGeneratedContentTypeMethodParameter(inputParam) + ? null + : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); serializationFormat = null; if (inputParam.IsApiVersion && ClientProvider.IsMultiServiceClient) @@ -1208,7 +1172,10 @@ internal static List GetMethodParameters( // when one was already published. UpdateParameterNameWithBackCompat(inputParam, inputParam.Name, client.BackCompatProvider, serviceMethod); - ParameterProvider? parameter = ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); + ParameterProvider? parameter = IsGeneratedContentTypeMethodParameter(inputParam) && + methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest + ? CreateContentTypeParameter(inputParam) + : ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); if (parameter is null) { continue; @@ -1249,7 +1216,7 @@ internal static List GetMethodParameters( break; case ParameterLocation.Query: case ParameterLocation.Header: - if (inputParam is InputHeaderParameter { IsContentType: true } + if (IsContentTypeParameter(inputParam) && !HasContentTypeBeforeBodyInLastContract(serviceMethod.Name, client.BackCompatProvider)) { sortedParams.Add(contentType++, parameter); @@ -1292,12 +1259,25 @@ internal static List GetMethodParameters( return [.. sortedParams.Values]; } + private static ParameterProvider CreateContentTypeParameter(InputParameter inputParam) + { + var type = new CSharpType(typeof(string), isNullable: !inputParam.IsRequired); + return new ParameterProvider( + inputParam.Name, + DocHelpers.GetFormattableDescription(inputParam.Summary, inputParam.Doc) ?? FormattableStringHelpers.Empty, + type, + defaultValue: inputParam.IsRequired ? null : Default, + location: ParameterLocation.Header, + wireInfo: new WireInformation(SerializationFormat.Default, inputParam.SerializedName), + validation: inputParam.IsRequired ? ParameterValidationType.AssertNotNullOrEmpty : ParameterValidationType.None, + inputParameter: inputParam); + } + private static bool HasLiteralContentTypeHeader(InputOperation operation) { foreach (var p in operation.Parameters) { - if (p is InputHeaderParameter { IsContentType: true } header - && header.Type is InputLiteralType) + if (p is InputHeaderParameter { IsContentType: true } && p.Type is InputLiteralType) { return true; } @@ -1305,6 +1285,14 @@ private static bool HasLiteralContentTypeHeader(InputOperation operation) return false; } + private static bool IsGeneratedContentTypeMethodParameter(InputParameter parameter) => + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + + private static bool IsContentTypeParameter(InputParameter parameter) => + parameter is InputHeaderParameter { IsContentType: true } || + IsGeneratedContentTypeMethodParameter(parameter); + /// /// Checks if the last contract view contains a method matching the given name where /// a "contentType" parameter appears before the body ("content") parameter. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs index af294640060..e902afd7c79 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -45,6 +46,7 @@ protected override TypeSignatureModifiers BuildDeclarationModifiers() protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", "Internal", $"{Name}.cs"); protected override string BuildName() => "SerializationFormat"; + protected override FormattableString BuildDescription() => $"The serialization format."; protected override TypeProvider[] BuildSerializationProviders() => []; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs index 588a2094b12..fdd8f72fba0 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs @@ -26,7 +26,9 @@ public static MethodBodyStatement SetContent(this ScopedApi pip public static MethodBodyStatement SetHeaderDelimited(this HttpRequestApi pipelineRequest, string name, ValueExpression value, ValueExpression delimiter, ValueExpression? format = null) { ValueExpression[] parameters = format != null ? [Literal(name), value, delimiter, format] : [Literal(name), value, delimiter]; - return pipelineRequest.Property(nameof(PipelineRequest.Headers)).Invoke("SetDelimited", parameters).Terminate(); + return pipelineRequest.Property(nameof(PipelineRequest.Headers)) + .Invoke("SetDelimited", parameters, typeArguments: null, callAsAsync: false, extensionType: new PipelineRequestHeadersExtensionsDefinition().Type) + .Terminate(); } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs new file mode 100644 index 00000000000..e5a47dbc62e --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -0,0 +1,734 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.TypeSpec.Generator.Input; +using Microsoft.TypeSpec.Generator.Tests.Common; +using NUnit.Framework; + +namespace Microsoft.TypeSpec.Generator.ClientModel.Tests.PostProcessing +{ + public class ClientBodyDependencyPostProcessingTests + { + [Test] + public async Task OperationBodyParameterModelDoesNotBecomePublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([requestModel], [client], ["RequestBody"]); + } + + [Test] + public async Task OperationResponseBodyModelRemainsPublicAsRootOutputModel() + { + var responseModel = InputFactory.Model("ResponseBody"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel], [client], ["ResponseBody"]); + } + + [Test] + public async Task OperationResponseBodyModelIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyModel = InputFactory.Model("MetadataOnlyResponse"); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyModel, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [metadataOnlyModel], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.Serialization.cs") + ]); + } + + [Test] + public async Task InternalAdditionalRootModelIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyModel = InputFactory.Model("MetadataOnlyResponse", access: "internal"); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyModel, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [metadataOnlyModel], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.Serialization.cs") + ], + configureGenerator: () => + { + var provider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "MetadataOnlyResponse"); + CodeModelGenerator.Instance.AddTypeToKeep(provider); + }); + } + + [Test] + public async Task AdditionalRootEnumIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyEnum = InputFactory.StringEnum( + "MetadataOnlyResponseKind", + [("Accepted", "accepted")]); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyEnum, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [metadataOnlyEnum], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponseKind.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponseKind.Serialization.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumIsRemovedWhenNotOtherwiseReferenced() + { + var contentTypeEnum = InputFactory.StringEnum( + "UpdateSnapshotRequestContentType", + [ + ("ApplicationMergePatchJson", "application/merge-patch+json"), + ("ApplicationJson", "application/json") + ]); + var contentTypeParameter = InputFactory.MethodParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + location: InputRequestLocation.Header, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "UpdateSnapshot", + parameters: [contentTypeParameter], + httpMethod: "PATCH", + generateConvenienceMethod: false); + var method = InputFactory.BasicServiceMethod( + "UpdateSnapshot", + operation, + parameters: [ + InputFactory.MethodParameter("name", InputPrimitiveType.String, isRequired: true, location: InputRequestLocation.Path), + contentTypeParameter, + InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true) + ]); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "UpdateSnapshotRequestContentType.cs"), + Path.Combine("src", "Generated", "Models", "UpdateSnapshotRequestContentType.Serialization.cs") + ], + configureGenerator: () => + CodeModelGenerator.Instance.TypeFactory.CreateCSharpType(InputFactory.Union([contentTypeEnum], "contentType"))); + } + + [Test] + public async Task PublicEnumIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyEnum = InputFactory.StringEnum( + "MetadataOnlyKind", + [("One", "one")]); + + await GenerateAndAssertFiles( + enums: [metadataOnlyEnum], + models: [], + clients: [], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyKind.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyKind.Serialization.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumReferencedByCustomSuppressionIsKept() + { + var contentTypeEnum = InputFactory.StringEnum( + "PutKeyValueRequestContentType", + [("ApplicationJson", "application/json")], + isExtensible: true); + var contentTypeParameter = InputFactory.HeaderParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + isContentType: true, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "SetConfigurationSettingInternal", + parameters: [contentTypeParameter], + httpMethod: "PUT"); + var method = InputFactory.BasicServiceMethod("SetConfigurationSettingInternal", operation); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [ + (Path.Combine("src", "PutKeyValueRequestContentType.cs"), """ + namespace Sample.Models; + + internal readonly partial struct PutKeyValueRequestContentType + { + public static PutKeyValueRequestContentType ApplicationJson { get; } = new PutKeyValueRequestContentType("application/json"); + } + """) + ], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "PutKeyValueRequestContentType.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumReferencedOnlyByCustomSuppressionAttributeIsKept() + { + var contentTypeEnum = InputFactory.StringEnum( + "CreateSnapshotRequestContentType", + [("ApplicationJson", "application/json")], + isExtensible: true); + var contentTypeParameter = InputFactory.MethodParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + location: InputRequestLocation.Header, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "CreateSnapshot", + parameters: [contentTypeParameter], + httpMethod: "PUT", + generateConvenienceMethod: false); + var method = InputFactory.BasicServiceMethod( + "CreateSnapshot", + operation, + parameters: [ + InputFactory.MethodParameter("name", InputPrimitiveType.String, isRequired: true, location: InputRequestLocation.Path), + contentTypeParameter, + InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true) + ]); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [ + (Path.Combine("src", "ConfigurationClient.cs"), """ + namespace Sample; + + [Microsoft.TypeSpec.Generator.Customizations.CodeGenType("ConfigurationClient")] + [Microsoft.TypeSpec.Generator.Customizations.CodeGenSuppress("CreateSnapshot", typeof(string), typeof(CreateSnapshotRequestContentType))] + public partial class ConfigurationClient + { + } + """) + ], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "CreateSnapshotRequestContentType.cs") + ]); + } + + [Test] + public async Task NestedBodyModelGraphDoesNotBecomePublic() + { + var nestedModel = InputFactory.Model("NestedToolParameter"); + var toolModel = InputFactory.Model( + "ToolConfig", + properties: [InputFactory.Property("Parameter", nestedModel)]); + var parameter = InputFactory.BodyParameter("tool", toolModel, isRequired: true); + var operation = InputFactory.Operation("Configure", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Configure", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([toolModel, nestedModel], [client], ["ToolConfig", "NestedToolParameter"]); + } + + [Test] + public async Task NonDiscriminatorDerivedBodyModelDoesNotBecomePublicFromPublicBase() + { + var baseTool = InputFactory.Model("BaseTool"); + var concreteTool = InputFactory.Model( + "ConcreteTool", + properties: [InputFactory.Property("Name", InputPrimitiveType.String)], + baseModel: baseTool); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseTool)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseTool, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertMixedModels( + [baseTool, concreteTool], + [client], + publicModelNames: ["BaseTool"], + internalModelNames: ["ConcreteTool"]); + } + + [Test] + public async Task PublicModelSignatureDependencyIsPromotedToPublic() + { + var internalDependency = InputFactory.Model("InternalDependency", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependency", internalDependency)]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel, internalDependency], [client], ["ResponseBody", "InternalDependency"]); + } + + [Test] + public async Task AzureClientPublicMethodSignatureReferencesStayPublic() + { + var signatureModel = InputFactory.Model("SignatureModel", @namespace: "Azure.Sample.Models"); + var methodParameter = InputFactory.MethodParameter("signature", signatureModel, isRequired: true); + var operation = InputFactory.Operation( + "Create", + parameters: [InputFactory.BodyParameter("signature", signatureModel, isRequired: true)], + httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation, parameters: [methodParameter]); + var client = InputFactory.Client("SampleClient", clientNamespace: "Azure.Sample", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [signatureModel], + clients: [client], + customFiles: [], + expectedFiles: [], + publicModelNames: ["SignatureModel"], + packageName: "Azure.Sample"); + } + + [Test] + public async Task BasePreservedDerivedModelTraversesTransitiveDependencies() + { + var transitiveDependency = InputFactory.Model("TransitiveDependency"); + var dependency = InputFactory.Model( + "DerivedDependency", + properties: [InputFactory.Property("Transitive", transitiveDependency)]); + var baseModel = InputFactory.Model("BaseResult"); + var derivedModel = InputFactory.Model( + "DerivedResult", + properties: [InputFactory.Property("Dependency", dependency)], + baseModel: baseModel); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [baseModel, derivedModel, dependency, transitiveDependency], + clients: [client], + customFiles: [], + expectedFiles: [], + publicModelNames: ["BaseResult"], + internalModelNames: ["DerivedResult", "DerivedDependency", "TransitiveDependency"]); + } + + [Test] + public async Task PublicCustomCodeArraySignatureReferencesStayPublic() + { + var generatedModel = InputFactory.Model("GeneratedModel"); + + await GenerateAndAssertFiles( + enums: [], + models: [generatedModel], + clients: [], + customFiles: [ + (Path.Combine("src", "PublicCustomApi.cs"), """ + using Sample.Models; + + namespace Sample; + + public partial class PublicCustomApi + { + public GeneratedModel[] Items { get; } = System.Array.Empty(); + } + """) + ], + expectedFiles: [], + publicModelNames: ["GeneratedModel"]); + } + + [Test] + public async Task GeneratedRequestHeaderSetDelimitedReferenceKeepsExtensions() + { + var header = InputFactory.HeaderParameter("x-ms-custom", InputFactory.Array(InputPrimitiveType.String), isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [header]); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [Path.Combine("src", "Generated", "Internal", "PipelineRequestHeadersExtensions.cs")]); + } + + [Test] + public async Task BinaryDataBodyParameterDoesNotKeepBinaryContentHelpers() + { + var parameter = InputFactory.BodyParameter( + "content", + InputPrimitiveType.Base64, + isRequired: true, + contentTypes: ["application/octet-stream"], + defaultContentType: "application/octet-stream"); + var operation = InputFactory.Operation("Upload", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod( + "Upload", + operation, + parameters: [InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true)]); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Internal", "BinaryContentHelper.cs"), + Path.Combine("src", "Generated", "Internal", "Utf8JsonBinaryContent.cs") + ]); + } + + [Test] + public async Task CollectionBodyParameterKeepsBinaryContentHelpers() + { + var parameter = InputFactory.BodyParameter("items", InputFactory.Array(InputPrimitiveType.String), isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod( + "Create", + operation, + parameters: [InputFactory.MethodParameter("items", InputFactory.Array(InputPrimitiveType.String), isRequired: true)]); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [ + Path.Combine("src", "Generated", "Internal", "BinaryContentHelper.cs"), + Path.Combine("src", "Generated", "Internal", "Utf8JsonBinaryContent.cs") + ]); + } + + [Test] + public async Task CustomOnlyRequestHeaderSetDelimitedReferenceKeepsExtensions() + { + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [], + customFiles: [ + (Path.Combine("src", "CustomHeaders.cs"), """ + using System.ClientModel.Primitives; + + namespace Sample; + + public static class CustomHeaders + { + public static void Add(PipelineRequestHeaders headers, string[] values) + => headers.SetDelimited("x-ms-custom", values, ","); + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Internal", "PipelineRequestHeadersExtensions.cs")]); + } + + [Test] + public async Task CustomizedEnumSerializationProviderIsKeptWhenModelSerializationUsesEnum() + { + var statusEnum = InputFactory.StringEnum( + "Status", + [("Succeeded", "succeeded"), ("Failed", "failed")], + clientNamespace: "Sample"); + var resultModel = InputFactory.Model( + "OperationResult", + properties: [InputFactory.Property("Status", statusEnum, isRequired: true)], + @namespace: "Sample"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: resultModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(resultModel, [])); + var client = InputFactory.Client("TestClient", methods: [method], clientNamespace: "Sample"); + + await GenerateAndAssertFiles( + enums: [statusEnum], + models: [resultModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Custom", "Status.cs"), """ + namespace Sample; + + [CodeGenType("Status")] + public enum Status + { + Succeeded, + Failed + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Models", "Status.Serialization.cs")]); + } + + [Test] + public async Task CustomModelFactoryPartialDoesNotKeepBodyOnlyModelPublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [requestModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "SampleModelFactory.cs"), """ + namespace Sample; + + [Microsoft.TypeSpec.Generator.Customizations.CodeGenType("SampleModelFactory")] + public static partial class SampleModelFactory + { + } + """) + ], + expectedFiles: [], + internalModelNames: ["RequestBody"]); + } + + [Test] + public async Task InternalCustomClientPartialOverridesLastContractPublicClient() + { + var responseModel = InputFactory.Model("CompactResource"); + var operation = InputFactory.Operation("Compact", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod("Compact", operation, response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("Responses", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [responseModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Generated", "Responses.cs"), """ + namespace Sample; + + public partial class Responses + { + } + """), + (Path.Combine("src", "Custom", "Internal", "Responses.cs"), """ + namespace Sample; + + internal partial class Responses + { + } + """) + ], + expectedFiles: [], + internalModelNames: ["CompactResource"], + internalClientNames: ["Responses"]); + } + + private static async Task GenerateAndAssertInternalModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: false); + + private static async Task GenerateAndAssertPublicModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: true); + + private static async Task GenerateAndAssertMixedModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + => await GenerateAndAssertModels(models, clients, publicModelNames, internalModelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames, + bool shouldBePublic) + => await GenerateAndAssertModels( + models, + clients, + shouldBePublic ? modelNames : [], + shouldBePublic ? [] : modelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + { + await GenerateAndAssertFiles( + enums: [], + models: models, + clients: clients, + customFiles: [], + publicModelNames: publicModelNames, + internalModelNames: internalModelNames, + expectedFiles: []); + } + + private static async Task GenerateAndAssertFiles( + InputEnumType[] enums, + InputModelType[] models, + InputClient[] clients, + (string Path, string Content)[] customFiles, + string[] expectedFiles, + string[] unexpectedFiles = null!, + string[] publicModelNames = null!, + string[] internalModelNames = null!, + string[] internalClientNames = null!, + string packageName = "Sample", + Action? configureGenerator = null) + { + publicModelNames ??= []; + internalModelNames ??= []; + internalClientNames ??= []; + unexpectedFiles ??= []; + + var outputPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Directory.CreateDirectory(outputPath); + try + { + foreach (var customFile in customFiles) + { + var customPath = Path.Combine(outputPath, customFile.Path); + Directory.CreateDirectory(Path.GetDirectoryName(customPath)!); + File.WriteAllText(customPath, customFile.Content); + } + + await MockHelpers.LoadMockGeneratorAsync( + inputEnums: () => enums, + inputModels: () => models, + clients: () => clients, + configuration: $$"""{ "package-name": "{{packageName}}", "disable-xml-docs": true }""", + outputPath: outputPath); + configureGenerator?.Invoke(); + + await new CSharpGen().ExecuteAsync(); + + foreach (var modelName in publicModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"public partial class {modelName}", text, $"{modelName} should be public."); + } + + foreach (var modelName in internalModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"internal partial class {modelName}", text, $"{modelName} should be internal."); + StringAssert.DoesNotContain($"public partial class {modelName}", text, $"{modelName} should not be public."); + } + + foreach (var clientName in internalClientNames) + { + var clientPath = Path.Combine(outputPath, "src", "Generated", $"{clientName}.cs"); + Assert.IsTrue(File.Exists(clientPath), $"Expected generated client file '{clientPath}'."); + var text = File.ReadAllText(clientPath); + StringAssert.Contains($"internal partial class {clientName}", text, $"{clientName} should be internal."); + StringAssert.DoesNotContain($"public partial class {clientName}", text, $"{clientName} should not be public."); + } + + var modelFactoryPath = Path.Combine(outputPath, "src", "Generated", "SampleModelFactory.cs"); + if (File.Exists(modelFactoryPath)) + { + var modelFactoryText = File.ReadAllText(modelFactoryPath); + foreach (var modelName in publicModelNames) + { + StringAssert.Contains($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should be generated."); + } + + foreach (var modelName in internalModelNames) + { + StringAssert.DoesNotContain($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should not be generated."); + } + } + + foreach (var expectedFile in expectedFiles) + { + var filePath = Path.Combine(outputPath, expectedFile); + Assert.IsTrue(File.Exists(filePath), $"Expected generated file '{filePath}'."); + } + + foreach (var unexpectedFile in unexpectedFiles) + { + var filePath = Path.Combine(outputPath, unexpectedFile); + Assert.IsFalse(File.Exists(filePath), $"Did not expect generated file '{filePath}'."); + } + } + finally + { + if (Directory.Exists(outputPath)) + { + Directory.Delete(outputPath, recursive: true); + } + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs index fcc90582416..503ed1ad68f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs @@ -121,31 +121,6 @@ public void JsonModelWriteCore_IsOverride_WhenBaseIsRegularModel() "JsonModelWriteCore should be 'override' with regular base too"); } - [Test] - public void JsonModelWriteCore_IsOverride_WhenBaseProviderIsResolvedAfterSerialization() - { - var baseInputModel = InputFactory.Model("Resource"); - var derivedInputModel = InputFactory.Model("TrackedResource", properties: [InputFactory.Property("Location", InputPrimitiveType.String)]); - MockHelpers.LoadMockGenerator(inputModels: () => [baseInputModel, derivedInputModel]); - - var derived = new DelayedBaseModelProvider(derivedInputModel); - var serialization = new MrwSerializationTypeDefinition(derivedInputModel, derived); - - // The serialization provider can be constructed before later visitors/customization - // resolution make the base model provider available. - derived.BaseModel = new SystemObjectModelProvider(new CSharpType(typeof(object)), baseInputModel); - - var method = serialization.BuildJsonModelWriteCoreMethod(); - - Assert.AreEqual(derived.BaseModel.Type, derived.Type.BaseType, - "The generated model type should inherit the base resolved after serialization construction."); - Assert.AreEqual(derived.BaseModel.Type, serialization.Type.BaseType, - "The serialization type should inherit the same resolved base."); - Assert.IsTrue(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Override), - "JsonModelWriteCore should evaluate BaseModelProvider when the method is built, not when serialization is constructed"); - Assert.IsFalse(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Virtual)); - } - // ------------------------------------------------------------------- // PersistableModelWriteCore: 'virtual' with system base, 'override' with regular // (the framework base already implements this; derived model re-introduces it) @@ -353,14 +328,5 @@ FakeMrwBase IPersistableModel.Create(BinaryData data, ModelReaderWr string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; } - - private class DelayedBaseModelProvider(InputModelType inputModel) : ModelProvider(inputModel) - { - public ModelProvider? BaseModel { get; set; } - - protected override ModelProvider? BuildBaseModelProvider() => BaseModel; - - protected override CSharpType? BuildBaseType() => BaseModel?.Type; - } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs index e8bcb2fa0fb..7e637e07363 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs @@ -738,38 +738,6 @@ public void ValidateGetResponseClassifiersThrowsWhenNoSuccess() Assert.Fail("Expected Exception to be thrown."); } - [Test] - public void TestBuildCreateRequestMethodWithExplodedModelQueryParameter() - { - var filterModel = InputFactory.Model( - "filterOptions", - properties: - [ - InputFactory.Property("field", InputPrimitiveType.String, isRequired: true), - InputFactory.Property("value", InputPrimitiveType.String, isRequired: true), - ]); - var operation = InputFactory.Operation( - "sampleOp", - parameters: [InputFactory.QueryParameter("filter", filterModel, isRequired: true, explode: true)]); - var client = InputFactory.Client( - "TestClient", - methods: [InputFactory.BasicServiceMethod("Test", operation)]); - var clientProvider = new ClientProvider(client); - var restClientProvider = new MockClientProvider(client, clientProvider); - - var method = restClientProvider.Methods.FirstOrDefault(m => m.Signature.Name == "CreateSampleOpRequest"); - Assert.IsNotNull(method); - var body = method!.BodyStatements!.ToDisplayString(); - - // A model-typed query parameter with `explode` is expanded into one query entry per - // property (RFC 6570 form explode) using each property's wire name, instead of serializing - // the whole object via ConvertToString (which produced the type name). - Assert.IsTrue(body.Contains("uri.AppendQuery(\"field\", filter.Field, true);"), body); - Assert.IsTrue(body.Contains("uri.AppendQuery(\"value\", filter.Value, true);"), body); - Assert.IsFalse(body.Contains("AppendQuery(\"filter\""), body); - Assert.IsFalse(body.Contains("ConvertToString(filter)"), body); - } - [Test] public void TestBuildCreateRequestMethodWithQueryParameters() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs index 9148f659e43..6c6e743ad89 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs @@ -34,7 +34,8 @@ public static async Task> LoadMockGeneratorAsync( Func>? apiVersions = null, string? configuration = null, Func? createCSharpTypeCore = null, - Func? createCSharpTypeCoreFallback = null) + Func? createCSharpTypeCoreFallback = null, + string? outputPath = null) { var mockGenerator = LoadMockGenerator( inputLiterals: inputLiterals, @@ -44,13 +45,13 @@ public static async Task> LoadMockGeneratorAsync( apiVersions: apiVersions, configuration: configuration, createCSharpTypeCore: createCSharpTypeCore, - createCSharpTypeCoreFallback: createCSharpTypeCoreFallback); + createCSharpTypeCoreFallback: createCSharpTypeCoreFallback, + outputPath: outputPath); var compilationResult = compilation == null ? null : await compilation(); var lastContractCompilationResult = lastContractCompilation == null ? null : await lastContractCompilation(); - var sourceInputModel = new Mock(() => new SourceInputModel(compilationResult, lastContractCompilationResult)) { CallBase = true }; - mockGenerator.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); + mockGenerator.SetupProperty(p => p.SourceInputModel, new SourceInputModel(compilationResult, lastContractCompilationResult)); return mockGenerator; } @@ -76,7 +77,8 @@ public static Mock LoadMockGenerator( Func? createOutputLibrary = null, bool includeXmlDocs = false, Func? createCSharpTypeCoreFallback = null, - Func? createModelCore = null) + Func? createModelCore = null, + string? outputPath = null) { IReadOnlyList inputNsApiVersions = apiVersions?.Invoke() ?? []; IReadOnlyList inputNsLiterals = inputLiterals?.Invoke() ?? []; @@ -150,7 +152,7 @@ public static Mock LoadMockGenerator( { configuration = "{\"disable-xml-docs\": false, \"package-name\": \"Sample.Namespace\"}"; } - object?[] parameters = [_configFilePath, configuration]; + object?[] parameters = [outputPath ?? _configFilePath, configuration]; var config = loadMethod?.Invoke(null, parameters); var mockGeneratorContext = new Mock(config!); var mockGeneratorInstance = new Mock(mockGeneratorContext.Object) { CallBase = true }; @@ -186,8 +188,7 @@ public static Mock LoadMockGenerator( mockGeneratorInstance.Setup(p => p.OutputLibrary).Returns(createOutputLibrary); } - var sourceInputModel = new Mock(() => new SourceInputModel(null, null)) { CallBase = true }; - mockGeneratorInstance.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); + mockGeneratorInstance.SetupProperty(p => p.SourceInputModel, new SourceInputModel(null, null)); codeModelInstance!.SetValue(null, mockGeneratorInstance.Object); clientModelInstance!.SetValue(null, mockGeneratorInstance.Object); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs index 7e5ec1c3c84..9328dc43571 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs @@ -564,42 +564,6 @@ public void DeserializeModelWithExternalMetadata() Assert.AreEqual("8.0.0", model.External.MinVersion); } - [Test] - public void DeserializeModelWithExternalUsagePreservesInputAndOutput() - { - // TCGC emits the External usage flag (UsageFlags.External) for models that are also - // referenced by external types. The C# InputModelTypeUsage enum must recognize it so - // that Enum.TryParse does not fail on the unknown token and collapse the whole usage to - // None, which would strip Input/Output and make every property get-only. - var json = @"{ - ""$id"": ""1"", - ""kind"": ""model"", - ""name"": ""TestModel"", - ""namespace"": ""Test.Models"", - ""crossLanguageDefinitionId"": ""Test.Models.TestModel"", - ""usage"": ""Input,Output,External"", - ""properties"": [] - }"; - - var referenceHandler = new TypeSpecReferenceHandler(); - var options = new JsonSerializerOptions - { - AllowTrailingCommas = true, - Converters = - { - new InputTypeConverter(referenceHandler), - new InputModelTypeConverter(referenceHandler), - new InputExternalTypeMetadataConverter() - } - }; - - var model = JsonSerializer.Deserialize(json, options); - Assert.IsNotNull(model); - Assert.IsTrue(model!.Usage.HasFlag(InputModelTypeUsage.Input), "Model should retain Input usage flag"); - Assert.IsTrue(model.Usage.HasFlag(InputModelTypeUsage.Output), "Model should retain Output usage flag"); - Assert.IsTrue(model.Usage.HasFlag(InputModelTypeUsage.External), "Model should have External usage flag"); - } - [Test] public void DeserializeArrayWithExternalMetadata() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index c013817a72e..35914c1f8d4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -27,12 +27,13 @@ public async Task ExecuteAsync() { CodeModelGenerator.Instance.Emitter.Info("Starting code generation"); CodeModelGenerator.Instance.Stopwatch.Start(); + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); var outputPath = CodeModelGenerator.Instance.Configuration.OutputDirectory; var generatedSourceOutputPath = CodeModelGenerator.Instance.Configuration.ProjectGeneratedDirectory; - // Resolve PackageReference items from the .csproj so custom code referencing - // external NuGet types (e.g., Azure.Storage.Common) compiles correctly. + // Resolve PackageReference items from the .csproj so custom code referencing external + // NuGet types compiles correctly. await GeneratedCodeWorkspace.AddPackageReferencesFromProject(); // Pre-walk the input library and resolve any external types that point at NuGet packages. @@ -90,12 +91,33 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), { // Ensure back-compatibility processing is done after all visitors have run outputType.ProcessTypeForBackCompatibility(); + } + + generatedCodeWorkspace.ApplyPreWriteAccessibility(output.TypeProviders); + generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders); + + foreach (var outputType in output.TypeProviders) + { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(outputType)) + { + continue; + } + + if (outputType is ModelFactoryProvider && outputType.Methods.Count == 0) + { + continue; + } var writer = CodeModelGenerator.Instance.GetWriter(outputType); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); foreach (var serialization in outputType.SerializationProviders) { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(serialization)) + { + continue; + } + writer = CodeModelGenerator.Instance.GetWriter(serialization); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); } @@ -104,6 +126,8 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), // Add all the generated files to the workspace await Task.WhenAll(generateFilesTasks); + ProviderReferenceMapAnalyzer.RestorePreWriteModelFactoryMethods(); + LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files @@ -112,14 +136,22 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); await generatedCodeWorkspace.PostProcessAsync(); + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); - // Write the generated files to the output directory + var generatedFiles = new List<(string Name, string Text)>(); await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) { if (string.IsNullOrEmpty(file.Text)) { continue; } + + generatedFiles.Add((file.Name, file.Text)); + } + + // Write the generated files to the output directory + foreach (var file in generatedFiles) + { var filename = Path.Combine(outputPath, file.Name); CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); Directory.CreateDirectory(Path.GetDirectoryName(filename)!); @@ -177,9 +209,10 @@ private static void DeleteDirectory(string path, string[] filesToKeep) return; } + var fileNamesToKeep = filesToKeep.ToHashSet(StringComparer.Ordinal); foreach (var file in directoryInfo.GetFiles("*", SearchOption.AllDirectories)) { - if (!filesToKeep.Contains(file.Name)) + if (!fileNamesToKeep.Contains(file.Name)) { file.Delete(); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs index e9e016c3554..6df7ccc9758 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs @@ -168,7 +168,7 @@ protected internal virtual void VisitLibrary(OutputLibrary library) /// /// The original . /// Null if it should be removed otherwise the modified version of the . - protected internal virtual ConstructorProvider? VisitConstructor(ConstructorProvider constructor) + protected virtual ConstructorProvider? VisitConstructor(ConstructorProvider constructor) { return constructor; } @@ -302,7 +302,7 @@ protected internal virtual FinallyExpression VisitFinallyExpression(FinallyExpre /// /// The original . /// Null if it should be removed otherwise the modified version of the . - protected internal virtual FieldProvider? VisitField(FieldProvider field) + protected virtual FieldProvider? VisitField(FieldProvider field) { return field; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index c36686f637f..74f9fc6b971 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -84,6 +84,16 @@ public async Task AddInMemoryFile(TypeProvider type) await UpdateProject(document); } + internal void AnalyzeProviderReferenceMap(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.Analyze(providers); + } + + internal void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility(providers); + } + private async Task UpdateProject(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -278,10 +288,8 @@ public async Task PostProcessAsync() case Configuration.UnreferencedTypesHandlingOption.KeepAll: break; case Configuration.UnreferencedTypesHandlingOption.Internalize: - _project = await postProcessor.InternalizeAsync(_project); break; case Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize: - _project = await postProcessor.InternalizeAsync(_project); _project = await postProcessor.RemoveAsync(_project); break; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index dc42f801732..bef5dce85b4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -9,7 +9,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Simplification; namespace Microsoft.TypeSpec.Generator { @@ -113,58 +112,6 @@ private async Task GetTypeSymbolsAsync(Compilation compilation, protected virtual bool ShouldIncludeDocument(Document document) => !GeneratedCodeWorkspace.IsGeneratedTestDocument(document); - /// - /// This method marks the "not publicly" referenced types as internal if they are previously defined as public. It will do this job in the following steps: - /// 1. This method will read all the public types defined in the given , and build a cache for those symbols - /// 2. Build a public reference map for those symbols - /// 3. Finds all the root symbols, please override the to control which document you would like to include - /// 4. Visit all the symbols starting from the root symbols following the reference map to get all unvisited symbols - /// 5. Change the accessibility of the unvisited symbols in step 4 to internal - /// - /// The project to process - /// The processed . is immutable, therefore this should usually be a new instance - public async Task InternalizeAsync(Project project) - { - var compilation = await project.GetCompilationAsync(); - if (compilation == null) - { - return project; - } - - // first get all the declared symbols - var definitions = await GetTypeSymbolsAsync(compilation, project, true); - // build the reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DeclaredNodesCache); - // get the root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // traverse all the root and recursively add all the things we met - var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); - - var nodesToInternalize = new Dictionary(); - foreach (var symbol in symbolsToInternalize) - { - foreach (var node in definitions.DeclaredNodesCache[symbol]) - { - nodesToInternalize[node] = project.GetDocumentId(node.SyntaxTree)!; - } - } - - foreach (var (model, documentId) in nodesToInternalize) - { - project = MarkInternal(project, model, documentId); - } - - var modelNamesToRemove = - nodesToInternalize.Keys.Select(item => item.Identifier.Text); - project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet()); - - return project; - } - private async Task RemoveMethodsFromModelFactoryAsync(Project project, TypeSymbols definitions, HashSet namesToRemove) @@ -246,25 +193,32 @@ public async Task RemoveAsync(Project project) // find all the declarations, including non-public declared var definitions = await GetTypeSymbolsAsync(compilation, project, false); - // build reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DocumentsCache); - // get root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal - // helpers that are required by the model factory. - if (_modelFactorySymbol != null) + IEnumerable symbolsToRemove; + HashSet referencedSet; + if (ProviderReferenceMapAnalyzer.LatestResult is { } referenceMapResult) { - rootSymbols.Add(_modelFactorySymbol); + // The remove pass uses the same precomputed hybrid map to avoid scanning all generated + // documents with Roslyn while preserving custom-code references as roots. + symbolsToRemove = GetSymbolsByName(definitions.DeclaredSymbols, referenceMapResult.RemoveCandidates).ToArray(); + referencedSet = new HashSet(definitions.DeclaredSymbols.Except(symbolsToRemove), SymbolEqualityComparer.Default); } - // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); - var referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + else + { + var referenceMap = await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DocumentsCache); + // Include model factory as a root symbol when doing the remove pass so that we are sure to include any internal + // helpers that are required by the model factory. + var rootSymbols = await GetRootSymbolsAsync(project, definitions); + if (_modelFactorySymbol != null) + { + rootSymbols.Add(_modelFactorySymbol); + } - var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); + referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + } var nodesToRemove = new List(); foreach (var symbol in symbolsToRemove) @@ -276,6 +230,14 @@ public async Task RemoveAsync(Project project) nodesToRemove.AddRange(definitions.DeclaredNodesCache[symbol]); } + var modelNamesToRemove = nodesToRemove + .Select(static item => item.Identifier.Text) + .ToHashSet(StringComparer.Ordinal); + if (modelNamesToRemove.Count > 0) + { + project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove); + } + // remove them one by one project = await RemoveModelsAsync(project, nodesToRemove); @@ -352,18 +314,19 @@ private static IEnumerable GetReferencedTypes(T definition, return Enumerable.Empty(); } - private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId) + private static IEnumerable GetSymbolsByName(IEnumerable symbols, HashSet names) { - var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); - var tree = declarationNode.SyntaxTree; - var document = project.GetDocument(documentId)!; - var newRoot = tree.GetRoot().ReplaceNode(declarationNode, newNode) - .WithAdditionalAnnotations(Simplifier.Annotation); - document = document.WithSyntaxRoot(newRoot); - return document.Project; + foreach (var symbol in symbols) + { + if (names.Contains(symbol.GetFullyQualifiedName())) + { + yield return symbol; + } + } } - private async Task RemoveModelsAsync(Project project, + private async Task RemoveModelsAsync( + Project project, IEnumerable unusedModels) { // accumulate the definitions from the same document together @@ -392,24 +355,6 @@ private async Task RemoveModelsAsync(Project project, return project; } - private static BaseTypeDeclarationSyntax ChangeModifier(BaseTypeDeclarationSyntax memberDeclaration, - SyntaxKind from, - SyntaxKind to) - { - var originalTokenInList = memberDeclaration.Modifiers.FirstOrDefault(token => token.IsKind(from)); - - // skip this if there is nothing to replace - if (originalTokenInList == default) - { - return memberDeclaration; - } - - var newToken = - SyntaxFactory.Token(originalTokenInList.LeadingTrivia, to, originalTokenInList.TrailingTrivia); - var newModifiers = memberDeclaration.Modifiers.Replace(originalTokenInList, newToken); - return memberDeclaration.WithModifiers(newModifiers); - } - private async Task RemoveModelsFromDocumentAsync(Project project, IEnumerable models) { @@ -479,7 +424,14 @@ private async Task RemoveInvalidUsings(Solution solution, DocumentId d if (invalidUsings.Count > 0) { + var leadingTrivia = invalidUsings[0].GetLeadingTrivia(); cu = cu.RemoveNodes(invalidUsings, SyntaxRemoveOptions.KeepNoTrivia)!; + if (leadingTrivia.Count > 0) + { + var firstToken = cu.GetFirstToken(includeZeroWidth: true); + cu = cu.ReplaceToken(firstToken, firstToken.WithLeadingTrivia(leadingTrivia.AddRange(firstToken.LeadingTrivia))); + } + solution = solution.WithDocumentSyntaxRoot(documentId, cu); } @@ -497,30 +449,37 @@ private async Task RemoveInvalidAttributes(Solution solution, Document return solution; } - var attributes = cu.DescendantNodes().OfType(); - var firstAttribute = attributes.FirstOrDefault(); + var attributeLists = cu.DescendantNodes().OfType().ToArray(); + var firstAttributeList = attributeLists.FirstOrDefault(); - var invalidAttributes = attributes - .Where(attr => attr.Attributes.Any(attribute => + var invalidAttributes = attributeLists + .SelectMany(static attr => attr.Attributes) + .Where(attribute => attribute.ArgumentList?.Arguments.Any(arg => arg.Expression is TypeOfExpressionSyntax typeOfExpr && - model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true)) + model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true) .ToHashSet(); if (invalidAttributes.Count > 0) { + var firstAttributeListRemoved = firstAttributeList != null && + firstAttributeList.Attributes.All(invalidAttributes.Contains); + var leadingTrivia = firstAttributeList?.GetLeadingTrivia(); cu = cu.RemoveNodes(invalidAttributes, SyntaxRemoveOptions.KeepNoTrivia)!; + var emptyAttributeLists = cu.DescendantNodes().OfType() + .Where(static list => list.Attributes.Count == 0) + .ToArray(); + cu = cu.RemoveNodes(emptyAttributeLists, SyntaxRemoveOptions.KeepNoTrivia)!; - if (invalidAttributes.Contains(firstAttribute!)) + if (firstAttributeListRemoved && leadingTrivia != null) { - var leadingTrivia = firstAttribute!.GetLeadingTrivia(); // Find where XML docs end and indentation begins var xmlDocTrivia = new List(); var lastXmlIndex = -1; - for (int i = 0; i < leadingTrivia.Count; i++) + for (int i = 0; i < leadingTrivia.Value.Count; i++) { - var trivia = leadingTrivia[i]; + var trivia = leadingTrivia.Value[i]; if (trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia)) { lastXmlIndex = i; @@ -532,14 +491,14 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && { for (int i = 0; i <= lastXmlIndex; i++) { - xmlDocTrivia.Add(leadingTrivia[i]); + xmlDocTrivia.Add(leadingTrivia.Value[i]); } // Include the newline after the last XML doc if present - if (lastXmlIndex + 1 < leadingTrivia.Count && - leadingTrivia[lastXmlIndex + 1].IsKind(SyntaxKind.EndOfLineTrivia)) + if (lastXmlIndex + 1 < leadingTrivia.Value.Count && + leadingTrivia.Value[lastXmlIndex + 1].IsKind(SyntaxKind.EndOfLineTrivia)) { - xmlDocTrivia.Add(leadingTrivia[lastXmlIndex + 1]); + xmlDocTrivia.Add(leadingTrivia.Value[lastXmlIndex + 1]); } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs new file mode 100644 index 00000000000..1058ba9d66b --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -0,0 +1,2218 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceMapResult? _latestResult; + private static readonly ConditionalWeakTable, Dictionary> _simpleNameLookupCache = new(); + private static TypeProvider? _preWriteModelFactory; + private static MethodProvider[]? _preWriteModelFactoryMethods; + + public static ProviderReferenceMapResult? LatestResult => _latestResult; + public static bool PreWriteAccessibilityApplied { get; private set; } + + public static bool ShouldWriteProvider(TypeProvider provider) => + _latestResult?.RemoveCandidates.Contains(GetProviderTypeName(provider.Type)) != true; + + public static void ResetPreWriteAccessibility() + { + RestorePreWriteModelFactoryMethods(); + _latestResult = null; + PreWriteAccessibilityApplied = false; + } + + public static void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + PreWriteAccessibilityApplied = false; + if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.KeepAll) + { + return; + } + + var (internalizeCandidates, publicizeCandidates) = GetPreWriteAccessibilityCandidates(providers); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (internalizeCandidates.Contains(providerName)) + { + provider.PreserveXmlDocs(); + provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers)); + } + else if (publicizeCandidates.Contains(providerName)) + { + provider.Update(modifiers: MakePublic(provider.DeclarationModifiers)); + } + } + + RemoveMethodsFromModelFactory(GetSimpleNames(internalizeCandidates)); + PreWriteAccessibilityApplied = true; + } + + public static void RestorePreWriteModelFactoryMethods() + { + if (_preWriteModelFactory == null || _preWriteModelFactoryMethods == null) + { + return; + } + + _preWriteModelFactory.Update(methods: _preWriteModelFactoryMethods); + _preWriteModelFactory = null; + _preWriteModelFactoryMethods = null; + } + + public static void Analyze(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customCodeRemovalRoots = GetCustomCodeGeneratedTypeRoots(generatedProviders, graph.Nodes); + var customRemovalRoots = new HashSet(customCodeRemovalRoots, StringComparer.Ordinal); + customRemovalRoots.UnionWith(apiBaselineGeneratedTypeRoots); + customRemovalRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + + // Helper types are rooted after an initial reachability pass so unused infrastructure + // such as change-tracking dictionaries can still be removed when no reachable type needs them. + var generatedDiscriminatorBaseNames = GetGeneratedPersistableModelProxyTypeNames(generatedProviders, publicGraph.Nodes); + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(providers, graph); + var removeCandidates = GetRemovalCandidates( + providers, + generatedProviders, + graph, + customRemovalRoots, + generatedDiscriminatorBaseNames); + + _latestResult = new ProviderReferenceMapResult( + internalizeCandidates, + publicizeCandidates, + removeCandidates); + RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates) GetPreWriteAccessibilityCandidates(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + var generatedDiscriminatorBaseNames = new HashSet(StringComparer.Ordinal); + + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + return (internalizeCandidates, publicizeCandidates); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates, HashSet InternalizeHelperRoots) GetAccessibilityCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + ProviderReferenceGraph publicGraph, + HashSet customPublicRoots, + HashSet customInternalDeclarations, + HashSet generatedInternalDeclarations, + HashSet generatedDiscriminatorBaseNames) + { + var internalizeReferences = CloneReferences(publicGraph.References); + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false, includeAdditionalRoots: true, includeUnionVariantRoots: false, publicClientRootsOnly: true); + if (ShouldUseUnionVariantFallbackRoots()) + { + AddUnionVariantRoots(internalizeRoots, providers, graph.Nodes); + } + + var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable, generatedDiscriminatorBaseNames); + internalizeRoots.UnionWith(customPublicRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + var publicizeRoots = new HashSet(internalizeRoots, StringComparer.Ordinal); + var internalizeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); + var customInternalBoundaryNodes = GetCustomInternalBoundaryNodes(publicGraph, customInternalDeclarations); + var publicizeDeclaredNodes = GetPublicizeDeclaredNodes(generatedProviders, graph.Nodes, internalizeDeclaredNodes); + var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedInternalDeclarations); + var publicApiTraversalNodes = GetPublicApiTraversalNodes( + internalizeDeclaredNodes, + publicizeDeclaredNodes, + generatedInternalDeclarations, + generatedImplementationInternalDeclarations); + var publicizeReachable = GetReachableTypes(publicizeRoots, internalizeReferences, publicApiTraversalNodes); + var internalizeCandidates = GetInternalizeCandidates( + internalizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + publicizeRoots); + var publicizeRootExclusions = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: true, + publicClientRootsOnly: true); + var publicizeCandidates = GetPublicizeCandidates( + publicizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + internalizeHelperRoots, + publicizeRootExclusions, + generatedInternalDeclarations, + publicizeRoots, + internalizeReferences, + generatedImplementationInternalDeclarations); + + return (internalizeCandidates, publicizeCandidates, internalizeHelperRoots); + } + + private static HashSet GetCustomInternalBoundaryNodes( + ProviderReferenceGraph publicGraph, + HashSet customInternalDeclarations) + { + var boundaryNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in publicGraph.Nodes) + { + if (!publicGraph.References.TryGetValue(node, out var references)) + { + continue; + } + + if (references.Overlaps(customInternalDeclarations)) + { + boundaryNodes.Add(node); + } + } + + return boundaryNodes; + } + + private static HashSet GetPublicizeDeclaredNodes( + IReadOnlyList generatedProviders, + HashSet nodes, + HashSet internalizeDeclaredNodes) + { + var publicizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, nodes, publicOnly: false); + publicizeDeclaredNodes.ExceptWith(internalizeDeclaredNodes); + return publicizeDeclaredNodes; + } + + private static HashSet GetPublicApiTraversalNodes( + HashSet internalizeDeclaredNodes, + HashSet publicizeDeclaredNodes, + HashSet generatedInternalDeclarations, + HashSet generatedImplementationInternalDeclarations) + { + var traversalNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (generatedInternalDeclarations.Contains(node) || + generatedImplementationInternalDeclarations.Contains(node)) + { + continue; + } + + traversalNodes.Add(node); + } + + foreach (var node in publicizeDeclaredNodes) + { + if (!generatedImplementationInternalDeclarations.Contains(node)) + { + traversalNodes.Add(node); + } + } + + return traversalNodes; + } + + private static HashSet GetInternalizeCandidates( + HashSet internalizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet publicizeRoots) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (!publicizeReachable.Contains(node) || + customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) && !publicizeRoots.Contains(node)) + { + candidates.Add(node); + } + } + + return candidates; + } + + private static HashSet GetPublicizeCandidates( + HashSet publicizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet internalizeHelperRoots, + HashSet publicizeRootExclusions, + HashSet generatedInternalDeclarations, + HashSet publicizeRoots, + Dictionary> internalizeReferences, + HashSet generatedImplementationInternalDeclarations) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in publicizeDeclaredNodes) + { + if (customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) || + internalizeHelperRoots.Contains(node) || + publicizeRootExclusions.Contains(node) || + !publicizeReachable.Contains(node)) + { + continue; + } + + if (generatedInternalDeclarations.Contains(node) && !publicizeRoots.Contains(node)) + { + continue; + } + + if (!publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, internalizeReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + { + continue; + } + + candidates.Add(node); + } + + return candidates; + } + + private static HashSet GetRemovalCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + HashSet customRemovalRoots, + HashSet generatedDiscriminatorBaseNames) + { + var removeRoots = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: false, + publicClientRootsOnly: false); + + removeRoots.UnionWith(customRemovalRoots); + AddMatchingNamesWithSimpleNameSuffix(removeRoots, "ReferenceType", graph.Nodes); + AddCustomCodeExtensionRoots(removeRoots, generatedProviders, graph.Nodes); + AddCustomizationBackedExtensionRoots(removeRoots, graph.Nodes); + AddCustomRequestHeaderExtensionsRoot(removeRoots, generatedProviders, graph.Nodes); + RemoveUnusedRequestHeaderExtensionsRoot(removeRoots, graph.References, providers); + + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddDerivedModelReferences(providers, graph.Nodes, graph.References, removeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachableWithoutHelpers); + + var removeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, removeReachableWithoutHelpers, graph.References); + removeRoots.UnionWith(removeHelperRoots); + + var removeReachable = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachable); + + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: false); + removeDeclaredNodes.ExceptWith(removeReachable); + return removeDeclaredNodes; + } + + private static HashSet GetCustomCodeGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: false); + } + + return roots; + } + + private static HashSet GetCustomCodePublicGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + } + + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: true); + } + + return roots; + } + + private static IEnumerable GetCustomCodeViews(IReadOnlyList providers) + { + var visited = new HashSet(StringComparer.Ordinal); + var modelFactoryCustomCodeView = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value.CustomCodeView; + if (modelFactoryCustomCodeView != null && visited.Add(GetCustomCodeViewIdentity(modelFactoryCustomCodeView))) + { + yield return modelFactoryCustomCodeView; + } + + foreach (var provider in providers) + { + var customCodeView = provider.CustomCodeView; + if (customCodeView == null || !visited.Add(GetCustomCodeViewIdentity(customCodeView))) + { + continue; + } + + yield return customCodeView; + } + + foreach (var customTypeProvider in CodeModelGenerator.Instance.SourceInputModel.GetCustomizationTypeProviders()) + { + if (visited.Add(GetCustomCodeViewIdentity(customTypeProvider))) + { + yield return customTypeProvider; + } + } + } + + private static string GetCustomCodeViewIdentity(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataName + : GetProviderTypeName(customCodeView.Type); + + private static void AddCustomRequestHeaderExtensionsRoot(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + // TODO: Resolve body-level SetDelimited extension calls to PipelineRequestHeadersExtensions so this can be a normal type edge. + if (!HasCustomRequestHeaderExtensionsReference(providers)) + { + return; + } + + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeaderExtensions", nodes); + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeadersExtensions", nodes); + } + + private static void AddCustomCodeExtensionRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", nodes); + } + } + + private static string GetCustomCodeViewSimpleName(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataSimpleName + : customCodeView.Type.Name; + + private static void AddCustomizationBackedExtensionRoots(HashSet roots, HashSet nodes) + { + foreach (var node in nodes) + { + var simpleName = GetSimpleName(node); + if (!simpleName.EndsWith("Extensions", StringComparison.Ordinal)) + { + continue; + } + + var namespaceName = GetNamespaceName(node); + if (namespaceName == null) + { + continue; + } + + var customTypeName = simpleName.Substring(0, simpleName.Length - "Extensions".Length); + if (CodeModelGenerator.Instance.SourceInputModel.FindForTypeInCustomization(namespaceName, customTypeName) != null) + { + roots.Add(node); + } + } + } + + private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames, bool publicOnly) + { + AddTypeReference(roots, customCodeView.BaseType, generatedTypeNames); + AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + if (!publicOnly) + { + AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); + } + + foreach (var implementedType in customCodeView.Implements) + { + AddTypeReference(roots, implementedType, generatedTypeNames); + } + + foreach (var constructor in customCodeView.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, constructor.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var method in customCodeView.Methods) + { + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, method.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var property in customCodeView.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(roots, property.Type, generatedTypeNames); + AddTypeReference(roots, property.ExplicitInterface, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, property.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + + foreach (var field in customCodeView.Fields) + { + if (publicOnly && !IsPublic(field.Modifiers)) + { + continue; + } + + AddTypeReference(roots, field.Type, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, field.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + } + + private static HashSet GetApiBaselineGeneratedTypeRoots(HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + if (string.IsNullOrEmpty(projectDirectory)) + { + return roots; + } + + var apiDirectory = Path.GetFullPath(Path.Combine(projectDirectory, "..", "api")); + if (!Directory.Exists(apiDirectory)) + { + return roots; + } + + var apiText = string.Join("\n", Directory.GetFiles(apiDirectory, "*.cs", SearchOption.AllDirectories).Select(File.ReadAllText)); + var apiDeclaredTypeNames = GetApiDeclaredTypeNames(apiText); + foreach (var fullName in generatedTypeNames) + { + var simpleName = StripGenericArity(GetSimpleName(fullName)); + var normalizedFullName = StripGenericArity(fullName); + if (!ContainsApiTypeReference(apiText, apiDeclaredTypeNames, normalizedFullName, simpleName)) + { + continue; + } + + roots.Add(fullName); + } + + return roots; + } + + private static HashSet GetApiDeclaredTypeNames(string apiText) + { + var declaredTypeNames = new HashSet(StringComparer.Ordinal); + string? currentNamespace = null; + foreach (var line in apiText.Split('\n')) + { + var namespaceMatch = Regex.Match(line, @"^namespace\s+([\w.]+)\s*\{?\s*$"); + if (namespaceMatch.Success) + { + currentNamespace = namespaceMatch.Groups[1].Value; + continue; + } + + if (currentNamespace == null) + { + continue; + } + + var declarationMatch = Regex.Match(line, @"^ \S.*?\b(class|struct|interface|enum)\s+([A-Za-z_][A-Za-z0-9_]*)(?!\s*<)(?!\w)"); + if (declarationMatch.Success) + { + declaredTypeNames.Add($"{currentNamespace}.{declarationMatch.Groups[2].Value}"); + } + } + + return declaredTypeNames; + } + + private static bool ContainsApiTypeReference(string apiText, HashSet apiDeclaredTypeNames, string fullName, string simpleName) + { + var fullNamePattern = $@"(? GetCustomCodeInternalGeneratedTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)) + { + continue; + } + + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(declarations, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + } + else + { + AddTypeReference(declarations, customCodeView.Type, generatedTypeNames); + } + } + + return declarations; + } + + private static HashSet GetGeneratedPersistableModelProxyTypeNames(IReadOnlyList providers, HashSet generatedTypeNames) + { + var proxyTypes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.Attributes.Any(static attribute => IsAttributeNamed(attribute, "PersistableModelProxy"))) + { + AddTypeReference(proxyTypes, provider.Type, generatedTypeNames); + } + } + + return proxyTypes; + } + + private static HashSet GetGeneratedInternalTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Internal); + + private static HashSet GetGeneratedPublicTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Public); + + private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessibility( + IReadOnlyList providers, + HashSet generatedTypeNames, + TypeSignatureModifiers accessibility) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.LastContractView?.DeclarationModifiers.HasFlag(accessibility) != true) + { + continue; + } + + AddTypeReference(declarations, provider.Type, generatedTypeNames); + } + + return declarations; + } + + private static HashSet GetGeneratedImplementationInternalTypeDeclarations(HashSet generatedInternalDeclarations) + { + var implementationDeclarations = new HashSet(StringComparer.Ordinal); + foreach (var name in generatedInternalDeclarations) + { + if (GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)) + { + implementationDeclarations.Add(name); + } + } + + return implementationDeclarations; + } + + private static HashSet GetSimpleNames(HashSet names) + { + var simpleNames = new HashSet(StringComparer.Ordinal); + foreach (var name in names) + { + simpleNames.Add(GetSimpleName(name)); + } + + return simpleNames; + } + + private static ProviderReferenceGraph BuildGraph(IReadOnlyList generatedProviders, bool publicOnly = false) + { + // Each generated provider becomes a node, and provider metadata supplies the edges: + // inheritance, signatures, properties, fields, nested/serialization providers, attributes, + // and selected implementation dependencies. This avoids parsing generated C# just to + // rediscover generated-to-generated references. + var serializationProviderNamesByType = GetSerializationProviderNamesByType(generatedProviders); + IReadOnlyDictionary? serializationReferenceNamesByType = publicOnly ? null : serializationProviderNamesByType; + var nodes = new HashSet(StringComparer.Ordinal); + var references = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (nodes.Add(providerName)) + { + references.Add(providerName, new HashSet(StringComparer.Ordinal)); + } + } + + foreach (var provider in generatedProviders) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.BaseType, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationReferenceNamesByType); + + if (IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) + { + continue; + } + + // Model factory signatures mention many models. The existing Roslyn post-processor + // removes factory methods for unreachable models, so model factory should only + // contribute helper dependencies, not model reachability edges. + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes, serializationReferenceNamesByType); + } + + if (!publicOnly) + { + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes, serializationReferenceNamesByType); + } + } + + if (!publicOnly) + { + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes, serializationReferenceNamesByType); + } + } + + foreach (var property in provider.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(references[current], property.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], property.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var field in provider.Fields) + { + if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) + { + continue; + } + + AddTypeReference(references[current], field.Type, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], field.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var constructor in provider.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], constructor.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], method.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + if (!publicOnly) + { + AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationReferenceNamesByType); + } + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static Dictionary GetSerializationProviderNamesByType(IReadOnlyList generatedProviders) + { + var namesByType = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + if (provider.SerializationProviders.Count == 0) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!namesByType.TryGetValue(providerName, out var serializationProviderNames)) + { + serializationProviderNames = new HashSet(StringComparer.Ordinal); + namesByType.Add(providerName, serializationProviderNames); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + serializationProviderNames.Add(GetProviderTypeName(serializationProvider.Type)); + } + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (providerName, serializationProviderNames) in namesByType) + { + result.Add(providerName, [.. serializationProviderNames]); + } + + return result; + } + + private static CSharpType? GetCollectionDefinitionType(MethodProvider method) + { + var property = method.GetType().GetProperty("CollectionDefinition"); + return property?.GetValue(method) is TypeProvider collectionDefinition + ? collectionDefinition.Type + : null; + } + + private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); + private static bool IsPublic(FieldModifiers modifiers) => modifiers.HasFlag(FieldModifiers.Public); + + private static TypeSignatureModifiers MakeInternal(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Public | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Internal; + + private static TypeSignatureModifiers MakePublic(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Internal | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Public; + + private static Dictionary> CloneReferences(IReadOnlyDictionary> references) + { + var clone = new Dictionary>(StringComparer.Ordinal); + foreach (var (name, referencedNames) in references) + { + clone.Add(name, new HashSet(referencedNames, StringComparer.Ordinal)); + } + + return clone; + } + + private static void AddDerivedModelReferences( + IReadOnlyList providers, + HashSet nodes, + Dictionary> references, + HashSet publicBaseModels, + HashSet generatedDiscriminatorBaseNames) + { + var modelProviders = new List(); + var discriminatorProviders = new List(); + var discriminatorBaseNames = new HashSet(StringComparer.Ordinal); + foreach (var provider in providers) + { + if (provider is not ModelProvider modelProvider || + !modelProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + modelProviders.Add(modelProvider); + + if (modelProvider.DiscriminatorProperty != null) + { + discriminatorBaseNames.Add(GetProviderTypeName(modelProvider.Type)); + } + + if (!modelProvider.IsUnknownDiscriminatorModel && + (modelProvider.DiscriminatorProperty != null || modelProvider.DiscriminatorValue != null)) + { + discriminatorProviders.Add(modelProvider); + } + } + + discriminatorBaseNames.UnionWith(generatedDiscriminatorBaseNames); + var addedReference = true; + while (addedReference) + { + addedReference = false; + foreach (var provider in discriminatorProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + if (!publicBaseModels.Contains(providerName)) + { + continue; + } + + foreach (var derivedModel in provider.DerivedModels) + { + if (derivedModel.IsUnknownDiscriminatorModel || + !derivedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var before = references[providerName].Count; + AddTypeReference(references[providerName], derivedModel.Type, nodes); + var derivedName = GetProviderTypeName(derivedModel.Type); + if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) + { + addedReference = true; + } + } + } + + foreach (var provider in modelProviders) + { + if (provider.IsUnknownDiscriminatorModel || + !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || + !discriminatorBaseNames.Contains(baseTypeName) || + !nodes.Contains(baseTypeName) || + !publicBaseModels.Contains(baseTypeName)) + { + continue; + } + + var before = references[baseTypeName].Count; + references[baseTypeName].Add(providerName); + if (publicBaseModels.Add(providerName) || references[baseTypeName].Count != before) + { + addedReference = true; + } + } + } + } + + private static void AddBasePreservedReferences( + IReadOnlyList providers, + HashSet nodes, + IReadOnlyDictionary> references, + HashSet reachableTypes) + { + var basePreservedRoots = new HashSet(StringComparer.Ordinal); + var addedRoot = true; + while (addedRoot) + { + addedRoot = false; + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName) || reachableTypes.Contains(providerName) || basePreservedRoots.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || !reachableTypes.Contains(baseTypeName)) + { + continue; + } + + if (basePreservedRoots.Add(providerName)) + { + addedRoot = true; + } + } + + if (addedRoot) + { + reachableTypes.UnionWith(GetReachableTypes(basePreservedRoots, references)); + } + } + } + + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + AddGeneratedProvider(generatedProviders, provider); + } + + return generatedProviders; + } + + private static void AddGeneratedProvider(List generatedProviders, TypeProvider provider) + { + generatedProviders.Add(provider); + foreach (var nestedType in provider.NestedTypes) + { + AddGeneratedProvider(generatedProviders, nestedType); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddGeneratedProvider(generatedProviders, serializationProvider); + } + } + + private static void AddGeneratedBodyReferences(IReadOnlyList providers, ProviderReferenceGraph graph) + { + foreach (var (provider, isSerializationProvider) in GetBodyReferenceProviders(providers)) + { + if (IsModelFactoryProvider(provider) || + !IsGeneratedBodyReferenceCandidate(provider, isSerializationProvider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + AddProviderBodyDependencyTypes( + graph.References[providerName], + GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), + graph.Nodes); + AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); + AddProviderInfrastructureReferences(graph.References[providerName], provider, graph.Nodes); + AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, graph.References[providerName]); + } + } + + private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes(TypeProvider provider, HashSet nodes) + { + var references = new List(); + foreach (var dependency in CollectStructuredBodyReferenceTypes(provider)) + { + if (!IsEnumProviderDependency(dependency, nodes)) + { + references.Add(dependency); + } + } + + return references; + } + + private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddMatchingName(references, "ProviderConstants", nodes); + AddMatchingName(references, "TypeFormatters", nodes); + + if (provider.SerializationProviders.Count > 0) + { + AddSerializationExtensionReferences(references, provider, nodes); + } + + if (IsSerializationProvider(provider)) + { + AddMatchingName(references, "Optional", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + AddMatchingName(references, "ModelSerializationExtensions", nodes); + AddSerializationExtensionReferences(references, provider, nodes); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + AddMethodInfrastructureReferences(references, method, nodes); + } + } + + private static void AddSerializationExtensionReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddSerializationExtensionReferences(references, provider.Type, nodes); + AddSerializationExtensionReferences(references, provider.BaseType, nodes); + foreach (var implementedType in provider.Implements) + { + AddSerializationExtensionReferences(references, implementedType, nodes); + } + + foreach (var property in provider.Properties) + { + AddSerializationExtensionReferences(references, property.Type, nodes); + } + + foreach (var field in provider.Fields) + { + AddSerializationExtensionReferences(references, field.Type, nodes); + } + + foreach (var constructor in provider.Constructors) + { + AddSerializationExtensionReferences(references, constructor.Signature.ReturnType, nodes); + foreach (var parameter in constructor.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + AddSerializationExtensionReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + } + + private static void AddSerializationExtensionReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + AddMatchingName(references, $"{type.Name}Extensions", nodes); + foreach (var argument in type.Arguments) + { + AddSerializationExtensionReferences(references, argument, nodes); + } + } + + private static void AddMethodInfrastructureReferences(HashSet references, MethodProvider method, HashSet nodes) + { + AddReturnTypeInfrastructureReferences(references, method.Signature.ReturnType, nodes); + } + + private static void AddReturnTypeInfrastructureReferences(HashSet references, CSharpType? returnType, HashSet nodes) + { + var type = UnwrapTask(returnType); + if (type == null) + { + return; + } + + var typeName = StripGenericArity(type.Name); + if (string.Equals(typeName, "Pageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "PageableWrapper", nodes); + } + else if (string.Equals(typeName, "AsyncPageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "AsyncPageableWrapper", nodes); + } + else if (string.Equals(typeName, "ArmOperation", StringComparison.Ordinal)) + { + AddMatchingNamesWithSimpleNameSuffix(references, "ArmOperation", nodes); + AddMatchingNamesWithSimpleNameSuffix(references, "OperationSource", nodes); + if (type.Arguments.Count > 0) + { + AddMatchingName(references, $"{BuildOperationSourceTypeName(type.Arguments[0])}OperationSource", nodes); + } + } + } + + private static CSharpType? UnwrapTask(CSharpType? type) + { + var typeName = type == null ? null : StripGenericArity(type.Name); + if ((string.Equals(typeName, "Task", StringComparison.Ordinal) || + string.Equals(typeName, "ValueTask", StringComparison.Ordinal)) && + type?.Arguments.Count > 0) + { + return type.Arguments[0]; + } + + return type; + } + + private static string BuildOperationSourceTypeName(CSharpType type) + { + var argumentNames = string.Join("", type.Arguments.Select(BuildOperationSourceTypeName)); + return $"{type.Name}{(argumentNames.Length > 0 ? "Of" : string.Empty)}{argumentNames}"; + } + + private static IReadOnlyList CollectStructuredBodyReferenceTypes(TypeProvider provider) + { + var references = new HashSet(); + var visited = new HashSet(ReferenceEqualityComparer.Instance); + + foreach (var field in provider.Fields) + { + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + } + + foreach (var property in provider.Properties) + { + CollectStructuredBodyReferenceTypes(property.Body, references, visited); + } + + foreach (var constructor in provider.Constructors) + { + CollectStructuredBodyReferenceTypes(constructor.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(constructor.BodyStatements, references, visited); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + CollectStructuredBodyReferenceTypes(method.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(method.BodyStatements, references, visited); + } + + return [.. references]; + } + + private static void CollectStructuredBodyReferenceTypes(object? value, HashSet references, HashSet visited) + { + switch (value) + { + case null: + case string: + case FormattableString: + return; + } + + if (!value.GetType().IsValueType && !visited.Add(value)) + { + return; + } + + switch (value) + { + case CSharpType type: + references.Add(type); + return; + case Type type: + references.Add(type); + return; + case ParameterProvider parameter: + references.Add(parameter.Type); + CollectStructuredBodyReferenceTypes(parameter.DefaultValue, references, visited); + CollectStructuredBodyReferenceTypes(parameter.InitializationValue, references, visited); + return; + case MethodSignatureBase signature: + CollectStructuredBodyReferenceTypes(signature.ReturnType, references, visited); + CollectStructuredBodyReferenceTypes(signature.Parameters, references, visited); + return; + case KeyValuePair positionalArgument: + CollectStructuredBodyReferenceTypes(positionalArgument.Value, references, visited); + return; + case FieldProvider field: + references.Add(field.Type); + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + return; + } + + if (IsStructuredBodyReferenceObject(value)) + { + foreach (var property in value.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + if (property.GetIndexParameters().Length > 0) + { + continue; + } + + CollectStructuredBodyReferenceTypes(property.GetValue(value), references, visited); + } + + return; + } + + if (value is not IEnumerable values) + { + return; + } + + foreach (var item in values) + { + CollectStructuredBodyReferenceTypes(item, references, visited); + } + } + + private static bool IsEnumProviderDependency(CSharpType dependency, HashSet nodes) + { + var providerName = GetProviderTypeName(dependency); + if (!nodes.Contains(providerName)) + { + return false; + } + + foreach (var provider in CodeModelGenerator.Instance.OutputLibrary.TypeProviders) + { + if (provider is EnumProvider && + string.Equals(GetProviderTypeName(provider.Type), providerName, StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + + private static bool IsStructuredBodyReferenceObject(object value) => + value is ValueExpression || + value is MethodBodyStatement || + value is PropertyBody; + + private static void AddProviderBodyDependencyTypes( + HashSet references, + IReadOnlyList dependencies, + HashSet nodes, + bool includeSimpleNameReferences = false) + { + foreach (var dependency in dependencies) + { + AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences); + } + } + + private static void AddProviderBodyDependencyType( + HashSet references, + CSharpType? dependency, + HashSet nodes, + bool includeSimpleNameReferences) + { + if (dependency == null) + { + return; + } + + AddTypeReference(references, dependency, nodes); + if (includeSimpleNameReferences) + { + AddMatchingName(references, dependency.Name, nodes); + } + if (nodes.Contains(GetProviderTypeName(dependency))) + { + AddMatchingName(references, $"{dependency.Name}Extensions", nodes); + } + else if (string.Equals(dependency.Name, "RequestContext", StringComparison.Ordinal)) + { + AddMatchingName(references, "RequestContextExtensions", nodes); + } + + foreach (var argument in dependency.Arguments) + { + AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences); + } + } + + private static IReadOnlyList<(TypeProvider Provider, bool IsSerializationProvider)> GetBodyReferenceProviders(IReadOnlyList providers) + { + var bodyReferenceProviders = new List<(TypeProvider Provider, bool IsSerializationProvider)>(); + foreach (var provider in providers) + { + bodyReferenceProviders.Add((provider, false)); + foreach (var serializationProvider in provider.SerializationProviders) + { + bodyReferenceProviders.Add((serializationProvider, true)); + } + } + + return bodyReferenceProviders; + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, bool isSerializationProvider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + return provider.IsReferenceMapRoot || + isSerializationProvider || + provider.IncludeGeneratedBodyReferences || + provider.HelperDependencyTypes.Count > 0 || + provider.BodyDependencyTypes.Count > 0; + } + + private static HashSet GetRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet helperRoots, + bool includeModelFactory, + bool includeAdditionalRoots, + bool includeUnionVariantRoots, + bool publicClientRootsOnly) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (IsReferenceMapRootProvider(provider, publicClientRootsOnly) || + includeAdditionalRoots && IsAdditionalRootProvider(provider, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + AddLastContractModelFactorySignatureRoots(providers, roots, nodes); + + if (!includeUnionVariantRoots) + { + return roots; + } + + AddUnionVariantRoots(roots, providers, nodes); + + return roots; + } + + private static void AddLastContractModelFactorySignatureRoots(IReadOnlyList providers, HashSet roots, HashSet nodes) + { + foreach (var provider in providers) + { + if (!IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var method in provider.LastContractView?.Methods ?? []) + { + if (!method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public) || + IsImplementationOnlyModelFactoryMethod(method)) + { + continue; + } + + AddTypeReference(roots, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddTypeReference(roots, parameter.Type, nodes); + } + } + } + } + + private static void AddUnionVariantRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + var unionVariantTypesToKeep = CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep; + foreach (var provider in GetGeneratedProviders(providers)) + { + if (!unionVariantTypesToKeep.Contains(provider.Type.Name) || + string.Equals(provider.Type.Namespace, "TypeSpec.Http", StringComparison.Ordinal)) + { + continue; + } + + AddMatchingName(roots, GetProviderTypeName(provider.Type), nodes); + } + } + + private static bool ShouldUseUnionVariantFallbackRoots() => + !HasApiBaselineDirectory() && + CodeModelGenerator.Instance.SourceInputModel.LastContract == null; + + private static bool IsImplementationOnlyModelFactoryMethod(MethodProvider method) + { + var returnType = method.Signature.ReturnType; + if (returnType == null) + { + return true; + } + + var returnTypeName = GetSimpleName(GetProviderTypeName(returnType)); + return returnTypeName.StartsWith("Paged", StringComparison.Ordinal) || + returnTypeName.EndsWith("Request", StringComparison.Ordinal); + } + + private static void RemoveMethodsFromModelFactory(HashSet namesToRemove) + { + if (namesToRemove.Count == 0) + { + return; + } + + var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value; + _preWriteModelFactory = modelFactory; + _preWriteModelFactoryMethods ??= [.. modelFactory.Methods]; + var methodsToKeep = new List(); + foreach (var method in modelFactory.Methods) + { + if (!namesToRemove.Contains(method.Signature.Name)) + { + methodsToKeep.Add(method); + } + } + + modelFactory.Update(methods: methodsToKeep); + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var excludedNames = generator.NonRootTypes; + var declaredNodes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (IsModelFactoryProvider(provider)) + { + continue; + } + + if (publicOnly && !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var name = GetProviderTypeName(provider.Type); + if (!nodes.Contains(name) || + excludedNames.Contains(name) || + excludedNames.Contains(GetSimpleName(name))) + { + continue; + } + + declaredNodes.Add(name); + } + + return declaredNodes; + } + + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) + { + var providerName = GetProviderTypeName(type); + if (roots.Contains(providerName) && nodes.Contains(providerName)) + { + return true; + } + + if (!roots.Contains(type.Name)) + { + return false; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + return simpleNameLookup.TryGetValue(type.Name, out var matches) && + matches.Length == 1 && + string.Equals(matches[0], providerName, StringComparison.Ordinal); + } + + private static bool IsReferenceMapRootProvider(TypeProvider provider, bool publicOnly) => + provider.IsReferenceMapRoot && + (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + + private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) + { + if (provider.DeclaringTypeProvider != null || !IsKept(provider.Type, roots, nodes)) + { + return false; + } + + return provider is not ModelProvider && provider is not EnumProvider; + } + + private static bool HasApiBaselineDirectory() + { + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + return !string.IsNullOrEmpty(projectDirectory) && + Directory.Exists(Path.GetFullPath(Path.Combine(projectDirectory, "..", "api"))); + } + + private static bool IsModelFactoryProvider(TypeProvider provider) + => provider is ModelFactoryProvider; + + private static HashSet GetHelperRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet reachableTypes, + IReadOnlyDictionary>? references = null) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + AddHelperDependencies(roots, provider.HelperDependencyTypes, nodes, references == null ? null : references[providerName]); + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + // Only factory methods for reachable models can instantiate collection helpers. + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddHelperDependencies( + HashSet roots, + IReadOnlyList dependencies, + HashSet nodes, + HashSet? referencedNames) + { + foreach (var dependency in dependencies) + { + if (referencedNames == null) + { + AddTypeReference(roots, dependency, nodes); + continue; + } + + var matches = new HashSet(StringComparer.Ordinal); + AddTypeReference(matches, dependency, nodes); + foreach (var match in matches) + { + if (referencedNames.Contains(match)) + { + roots.Add(match); + } + } + } + } + + private static void RemoveUnusedRequestHeaderExtensionsRoot( + HashSet roots, + IReadOnlyDictionary> references, + IReadOnlyList providers) + { + var hasCustomReference = HasCustomRequestHeaderExtensionsReference(providers); + if (hasCustomReference) + { + return; + } + + var unusedRequestHeaderExtensions = new List(); + foreach (var root in roots) + { + if (IsRequestHeadersExtensionsRoot(root) && + !HasExternalReference(root, references)) + { + unusedRequestHeaderExtensions.Add(root); + } + } + + roots.ExceptWith(unusedRequestHeaderExtensions); + } + + private static bool HasExternalReference(string root, IReadOnlyDictionary> references) + { + foreach (var (source, sourceReferences) in references) + { + if (!string.Equals(source, root, StringComparison.Ordinal) && + sourceReferences.Contains(root)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeadersExtensionsRoot(string root) => + root.EndsWith(".RequestHeaderExtensions", StringComparison.Ordinal) || + root.EndsWith(".RequestHeadersExtensions", StringComparison.Ordinal); + + private static bool HasCustomRequestHeaderExtensionsReference(IReadOnlyList providers) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (customCodeView is NamedTypeSymbolProvider) + { + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.SignatureDependencyTypes)) + { + return true; + } + + continue; + } + + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsMethodDependency(customCodeView.Methods) || + HasRequestHeaderExtensionsPropertyDependency(customCodeView.Properties) || + HasRequestHeaderExtensionsFieldDependency(customCodeView.Fields)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsDependency(IEnumerable dependencies) + { + foreach (var dependency in dependencies) + { + if (IsRequestHeaderExtensionsDependency(dependency)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsMethodDependency(IReadOnlyList methods) + { + foreach (var method in methods) + { + if (IsRequestHeaderExtensionsDependency(method.Signature.ReturnType)) + { + return true; + } + + foreach (var parameter in method.Signature.Parameters) + { + if (IsRequestHeaderExtensionsDependency(parameter.Type)) + { + return true; + } + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsPropertyDependency(IReadOnlyList properties) + { + foreach (var property in properties) + { + if (IsRequestHeaderExtensionsDependency(property.Type)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsFieldDependency(IReadOnlyList fields) + { + foreach (var field in fields) + { + if (IsRequestHeaderExtensionsDependency(field.Type)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeaderExtensionsDependency(string name) + => string.Equals(name, "RequestHeaderExtensions", StringComparison.Ordinal) || + string.Equals(name, "SetDelimited", StringComparison.Ordinal); + + private static bool IsRequestHeaderExtensionsDependency(CSharpType? type) + { + if (type == null) + { + return false; + } + + if (IsRequestHeaderExtensionsDependency(type.Name)) + { + return true; + } + + foreach (var argument in type.Arguments) + { + if (IsRequestHeaderExtensionsDependency(argument)) + { + return true; + } + } + + return false; + } + + private static bool IsSerializationProvider(TypeProvider provider) + { + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || + relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + if (!simpleNameLookup.TryGetValue(name, out var matches)) + { + return; + } + + foreach (var match in matches) + { + target.Add(match); + } + } + + private static void AddMatchingNamesWithSimpleNameSuffix(HashSet target, string suffix, HashSet nodes) + { + foreach (var node in nodes) + { + if (GetSimpleName(node).EndsWith(suffix, StringComparison.Ordinal)) + { + target.Add(node); + } + } + } + + private static Dictionary BuildSimpleNameLookup(HashSet nodes) + { + var lookup = new Dictionary>(StringComparer.Ordinal); + foreach (var node in nodes) + { + var simpleName = StripGenericArity(GetSimpleName(node)); + if (!lookup.TryGetValue(simpleName, out var matchingNodes)) + { + matchingNodes = []; + lookup.Add(simpleName, matchingNodes); + } + + matchingNodes.Add(node); + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (simpleName, matchingNodes) in lookup) + { + result.Add(simpleName, [.. matchingNodes]); + } + + return result; + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + return GetReachableTypes(roots, references, expandableNodes: null); + } + + private static HashSet GetReachableTypes( + HashSet roots, + IReadOnlyDictionary> references, + HashSet? expandableNodes) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (expandableNodes != null && !expandableNodes.Contains(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static bool HasPublicApiPredecessor( + string name, + IReadOnlyDictionary> references, + HashSet publicizeReachable, + HashSet generatedImplementationInternalDeclarations) + { + foreach (var (owner, children) in references) + { + if (!publicizeReachable.Contains(owner) || + string.Equals(owner, name, StringComparison.Ordinal) || + generatedImplementationInternalDeclarations.Contains(owner) || + !children.Contains(name)) + { + continue; + } + + return true; + } + + return false; + } + + private static void AddSignatureReferences( + HashSet references, + MethodSignatureBase signature, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeAttributes = true, + bool includeAttributeArguments = true) + { + AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddAttributes( + HashSet references, + IReadOnlyList attributes, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeArguments) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); + if (!includeArguments) + { + continue; + } + + foreach (var argument in attribute.Arguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + + foreach (var (_, argument) in attribute.PositionalArguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + } + } + + private static bool IsAttributeNamed(AttributeStatement attribute, string name) + => string.Equals(attribute.Type.Name, name, StringComparison.Ordinal) || + string.Equals(attribute.Type.Name, $"{name}Attribute", StringComparison.Ordinal); + + private static void AddAttributeArgumentReference( + HashSet references, + ValueExpression argument, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType) + { + if (argument is TypeOfExpression typeOf) + { + AddTypeReference(references, typeOf.Type, nodes, serializationProviderNamesByType); + AddMatchingName(references, typeOf.Type.Name, nodes); + } + } + + private static void AddTypeReference( + HashSet references, + CSharpType? type, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType = null) + { + if (type == null) + { + return; + } + + if (type.IsArray) + { + AddTypeReference(references, type.ElementType, nodes, serializationProviderNamesByType); + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) + { + foreach (var serializationProviderName in serializationProviderNames) + { + references.Add(serializationProviderName); + } + } + } + + AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes, serializationProviderNamesByType); + } + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string? GetNamespaceName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? null : fullyQualifiedName.Substring(0, lastDot); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs new file mode 100644 index 00000000000..eafe1d9d546 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed record ProviderReferenceMapResult( + HashSet InternalizeCandidates, + HashSet PublicizeCandidates, + HashSet RemoveCandidates) + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs index 49fe9723973..eb07aa4519f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs @@ -45,7 +45,7 @@ private bool IsPublicContext(TypeProvider provider) private void WriteType(CodeWriter writer) { - if (IsPublicContext(_provider)) + if (_provider.PreserveTypeXmlDocs || IsPublicContext(_provider)) { writer.WriteXmlDocsNoScope(_provider.XmlDocs); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs index ed3a45d54e8..26cc96af377 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs @@ -22,6 +22,7 @@ internal sealed class NamedTypeSymbolProvider : TypeProvider { private INamedTypeSymbol _namedTypeSymbol; private readonly Compilation _compilation; + private string? _metadataName; private TypeProvider? _baseTypeProvider; public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation compilation) @@ -30,6 +31,23 @@ public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation com _compilation = compilation; } + internal string MetadataName + { + get + { + if (_metadataName != null) + { + return _metadataName; + } + + var ns = _namedTypeSymbol.ContainingNamespace.GetFullyQualifiedNameFromDisplayString(); + _metadataName = string.IsNullOrEmpty(ns) ? _namedTypeSymbol.Name : $"{ns}.{_namedTypeSymbol.Name}"; + return _metadataName; + } + } + + internal string MetadataSimpleName => _namedTypeSymbol.Name; + private protected sealed override NamedTypeSymbolProvider? BuildCustomCodeView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => null; private protected sealed override TypeProvider? BuildLastContractView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => null; @@ -321,6 +339,165 @@ [.. methodSymbol.Parameters.Select(p => ConvertToParameterProvider(methodSymbol, return [.. methods]; } + protected internal override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new HashSet(); + foreach (var syntaxReference in _namedTypeSymbol.DeclaringSyntaxReferences) + { + AddBodyDependencyTypes(syntaxReference.GetSyntax(), dependencies); + } + + return [.. dependencies]; + } + + protected internal override IReadOnlyList BuildSignatureDependencyTypes() + { + var dependencies = new HashSet(); + foreach (var syntaxReference in _namedTypeSymbol.DeclaringSyntaxReferences) + { + if (syntaxReference.GetSyntax() is not TypeDeclarationSyntax typeDeclaration || + !IsPublic(typeDeclaration.Modifiers)) + { + continue; + } + + AddSyntaxTypeReferences(typeDeclaration.BaseList, dependencies); + foreach (var member in typeDeclaration.Members) + { + if (IsPublicApiMember(member)) + { + AddPublicSignatureDependencyTypes(member, dependencies); + } + } + } + + return [.. dependencies]; + } + + private void AddBodyDependencyTypes(SyntaxNode syntax, HashSet dependencies) + { + AddSyntaxTypeReferences(syntax, dependencies); + + foreach (var invocation in syntax.DescendantNodes().OfType()) + { + if (GetInvocationName(invocation) == "SetDelimited") + { + dependencies.Add(CreateUnresolvedDependencyType("SetDelimited")); + } + } + } + + private static void AddPublicSignatureDependencyTypes(MemberDeclarationSyntax member, HashSet dependencies) + { + switch (member) + { + case MethodDeclarationSyntax method: + AddSyntaxTypeReferences(method.ReturnType, dependencies); + AddSyntaxTypeReferences(method.ParameterList, dependencies); + AddSyntaxTypeReferences(method.ConstraintClauses, dependencies); + break; + case ConstructorDeclarationSyntax constructor: + AddSyntaxTypeReferences(constructor.ParameterList, dependencies); + break; + case ConversionOperatorDeclarationSyntax conversion: + AddSyntaxTypeReferences(conversion.Type, dependencies); + AddSyntaxTypeReferences(conversion.ParameterList, dependencies); + break; + case OperatorDeclarationSyntax @operator: + AddSyntaxTypeReferences(@operator.ReturnType, dependencies); + AddSyntaxTypeReferences(@operator.ParameterList, dependencies); + break; + case PropertyDeclarationSyntax property: + AddSyntaxTypeReferences(property.Type, dependencies); + break; + case IndexerDeclarationSyntax indexer: + AddSyntaxTypeReferences(indexer.Type, dependencies); + AddSyntaxTypeReferences(indexer.ParameterList, dependencies); + break; + case FieldDeclarationSyntax field: + AddSyntaxTypeReferences(field.Declaration.Type, dependencies); + break; + case EventFieldDeclarationSyntax eventField: + AddSyntaxTypeReferences(eventField.Declaration.Type, dependencies); + break; + case EventDeclarationSyntax @event: + AddSyntaxTypeReferences(@event.Type, dependencies); + break; + case DelegateDeclarationSyntax @delegate: + AddSyntaxTypeReferences(@delegate.ReturnType, dependencies); + AddSyntaxTypeReferences(@delegate.ParameterList, dependencies); + AddSyntaxTypeReferences(@delegate.ConstraintClauses, dependencies); + break; + case BaseTypeDeclarationSyntax type: + AddSyntaxTypeReferences(type.BaseList, dependencies); + break; + } + } + + private static void AddSyntaxTypeReferences(SyntaxNode? node, HashSet dependencies) + { + if (node == null) + { + return; + } + + foreach (var name in node.DescendantNodesAndSelf().OfType()) + { + dependencies.Add(CreateUnresolvedDependencyType(name.Identifier.ValueText)); + } + + foreach (var name in node.DescendantNodesAndSelf().OfType()) + { + dependencies.Add(CreateUnresolvedDependencyType(name.Identifier.ValueText)); + } + } + + private static void AddSyntaxTypeReferences(IEnumerable nodes, HashSet dependencies) + { + foreach (var node in nodes) + { + AddSyntaxTypeReferences(node, dependencies); + } + } + + private static bool IsPublicApiMember(MemberDeclarationSyntax member) + => member switch + { + EventDeclarationSyntax @event => IsPublic(@event.Modifiers), + EventFieldDeclarationSyntax @event => IsPublic(@event.Modifiers), + BaseFieldDeclarationSyntax field => IsPublic(field.Modifiers), + BaseMethodDeclarationSyntax method => IsPublic(method.Modifiers), + BasePropertyDeclarationSyntax property => IsPublic(property.Modifiers), + DelegateDeclarationSyntax @delegate => IsPublic(@delegate.Modifiers), + BaseTypeDeclarationSyntax type => IsPublic(type.Modifiers), + _ => false + }; + + private static bool IsPublic(SyntaxTokenList modifiers) + => modifiers.Any(static modifier => + modifier.IsKind(SyntaxKind.PublicKeyword) || + modifier.IsKind(SyntaxKind.ProtectedKeyword)); + + private static CSharpType CreateUnresolvedDependencyType(string name) + => new( + name, + string.Empty, + isValueType: false, + isNullable: false, + declaringType: null, + args: [], + isPublic: false, + isStruct: false); + + private static string? GetInvocationName(InvocationExpressionSyntax invocation) + => invocation.Expression switch + { + IdentifierNameSyntax identifier => identifier.Identifier.ValueText, + MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.ValueText, + GenericNameSyntax genericName => genericName.Identifier.ValueText, + _ => null + }; + private static bool IsPartialMethodDeclaration(IMethodSymbol methodSymbol) { foreach (var syntaxReference in methodSymbol.DeclaringSyntaxReferences) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 3d71670d5a9..4a073025e25 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -143,6 +143,13 @@ public XmlDocProvider XmlDocs private set => _xmlDocs = value; } + internal bool PreserveTypeXmlDocs { get; private set; } + + internal void PreserveXmlDocs() + { + PreserveTypeXmlDocs = true; + } + public string? Deprecated { get => _deprecated; @@ -292,6 +299,22 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SerializationProviders => _serializationProviders ??= BuildSerializationProviders(); + private IReadOnlyList? _helperDependencyTypes; + internal IReadOnlyList HelperDependencyTypes => _helperDependencyTypes ??= BuildHelperDependencyTypes(); + protected internal virtual IReadOnlyList BuildHelperDependencyTypes() => []; + + private IReadOnlyList? _bodyDependencyTypes; + internal IReadOnlyList BodyDependencyTypes => _bodyDependencyTypes ??= BuildBodyDependencyTypes(); + protected internal virtual IReadOnlyList BuildBodyDependencyTypes() => []; + + private IReadOnlyList? _signatureDependencyTypes; + internal IReadOnlyList SignatureDependencyTypes => _signatureDependencyTypes ??= BuildSignatureDependencyTypes(); + protected internal virtual IReadOnlyList BuildSignatureDependencyTypes() => []; + + protected internal virtual bool IsReferenceMapRoot => false; + + protected internal virtual bool IncludeGeneratedBodyReferences => false; + private IReadOnlyList? _attributes; public IReadOnlyList Attributes @@ -538,6 +561,7 @@ public virtual void Reset() _serializationProviders = null; _nestedTypes = null; _xmlDocs = null; + PreserveTypeXmlDocs = false; _declarationModifiers = null; _relativeFilePath = null; _customCodeView = new(() => BuildCustomCodeView()); @@ -741,75 +765,10 @@ internal void ProcessTypeForBackCompatibility() { _enumValues = updatedEnumValues; } - - // Back-compatibility processing intentionally runs after the library visitor pass so - // that the contract comparison uses the final, post-visitor member signatures (otherwise - // we could incorrectly decide whether a back-compat member is needed). As a result, any - // members synthesized above (e.g. back-compat overloads) have not been visited yet. Run - // only those newly-added members through the visitors now so visitor transforms apply to - // them as well, without re-visiting members that were already visited during the main pass. - if (newMethods != null) - { - newMethods = VisitNewMembers(newMethods, Methods, static (member, visitor) => member.Accept(visitor)); - } - if (newConstructors != null) - { - newConstructors = VisitNewMembers(newConstructors, Constructors, static (member, visitor) => visitor.VisitConstructor(member)); - } - if (newFields != null) - { - newFields = VisitNewMembers(newFields, Fields, static (member, visitor) => visitor.VisitField(member)); - } - Update(fields: newFields, methods: newMethods, constructors: newConstructors); } } - // Runs newly-added back-compatibility members through every registered visitor while leaving - // members that were already visited during the main visitor pass untouched. Membership in the - // already-visited set is determined by reference identity against the pre-Update collection. - private static IReadOnlyList VisitNewMembers( - IEnumerable allMembers, - IReadOnlyList alreadyVisited, - Func visit) - where T : class - { - var visitors = CodeModelGenerator.Instance.Visitors; - var materialized = allMembers as IReadOnlyList ?? [.. allMembers]; - if (visitors.Count == 0) - { - return materialized; - } - - var alreadyVisitedSet = new HashSet(alreadyVisited, ReferenceEqualityComparer.Instance); - var result = new List(materialized.Count); - foreach (var member in materialized) - { - if (alreadyVisitedSet.Contains(member)) - { - result.Add(member); - continue; - } - - T? visited = member; - foreach (var visitor in visitors) - { - visited = visit(visited, visitor); - if (visited == null) - { - break; - } - } - - if (visited != null) - { - result.Add(visited); - } - } - - return result; - } - protected internal virtual IReadOnlyList? BuildEnumValuesForBackCompatibility(IReadOnlyList originalEnumValues) => null; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs index a329166ee4b..4792e781338 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs @@ -22,6 +22,7 @@ public class SourceInputModel public ApiCompatBaseline ApiCompatBaseline { get; } private readonly Lazy> _nameMap; + private readonly Lazy> _customizationTypeProviders; public SourceInputModel(Compilation? customization, Compilation? lastContract) : this(customization, lastContract, ApiCompatBaseline.Empty) @@ -35,6 +36,7 @@ public SourceInputModel(Compilation? customization, Compilation? lastContract, A ApiCompatBaseline = apiCompatBaseline ?? ApiCompatBaseline.Empty; _nameMap = new(PopulateNameMap); + _customizationTypeProviders = new(PopulateCustomizationTypeProviders); } private IReadOnlyDictionary PopulateNameMap() @@ -70,6 +72,30 @@ private IReadOnlyDictionary PopulateNameMap() return FindTypeInCompilation(LastContract, ns, name, true, declaringTypeName, includeInternal: false); } + private IReadOnlyList PopulateCustomizationTypeProviders() + { + var providers = new List(); + if (Customization == null) + { + return providers; + } + + foreach (IModuleSymbol module in Customization.Assembly.Modules) + { + foreach (var type in SourceInputHelper.GetSymbols(module.GlobalNamespace)) + { + if (type is INamedTypeSymbol namedTypeSymbol) + { + providers.Add(new NamedTypeSymbolProvider(namedTypeSymbol, Customization)); + } + } + } + + return providers; + } + + internal IReadOnlyList GetCustomizationTypeProviders() => _customizationTypeProviders.Value; + private TypeProvider? FindTypeInCompilation( Compilation? compilation, string ns, diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs index cd70e60ea9a..46dbff6f96c 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/TypeFactory.cs @@ -21,6 +21,9 @@ public class TypeFactory private ChangeTrackingDictionaryDefinition ChangeTrackingDictionaryProvider { get; } = new(); + private OptionalDefinition? _optionalProvider; + private OptionalDefinition OptionalProvider => _optionalProvider ??= new(); + private Dictionary InputTypeToModelProvider { get; } = []; public IDictionary CSharpTypeMap { get; } = new Dictionary(CSharpType.IgnoreNullableComparer); @@ -200,11 +203,6 @@ protected internal TypeFactory() if (modelProvider != null) { - if (model.Access == "public") - { - CodeModelGenerator.Instance.AddTypeToKeep(modelProvider); - } - CSharpTypeMap[modelProvider.Type] = modelProvider; TypeProvidersByName[modelProvider.Type.Name] = modelProvider; } @@ -299,11 +297,6 @@ protected virtual ModelFactoryProvider CreateModelFactoryCore(IEnumerable enumProvider, }; - if (enumType.Access == "public") - { - CodeModelGenerator.Instance.AddTypeToKeep(enumProvider); - } - EnumCache.Add(enumCacheKey, enumProvider); if (enumProvider != null) @@ -500,6 +493,11 @@ inputProperty.Type is InputArrayType && /// public virtual CSharpType DictionaryInitializationType => ChangeTrackingDictionaryProvider.Type; + /// + /// The type used to represent optional values in generated helper code. + /// + public virtual CSharpType OptionalType => OptionalProvider.Type; + /// /// Returns the serialization type providers for the given model type provider. /// diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs index 8af371a063c..bf8cc5eee38 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs @@ -207,7 +207,7 @@ private static CSharpType ConstructCSharpTypeFromSymbol( string ns = string.Join('.', pieces.Take(pieces.Length - 1)); CSharpType? containingType = null; - if (typeSymbol.ContainingType != null) + if (typeSymbol.ContainingType != null && typeSymbol.TypeKind != TypeKind.TypeParameter) { containingType = GetCSharpType(typeSymbol.ContainingType); ns = string.Join('.', pieces.Take(pieces.Length - 2)); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs index f716aad5ab3..240dd77a792 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs @@ -451,7 +451,7 @@ private class TestFilterVisitor : LibraryVisitor return method; } - protected internal override ConstructorProvider? VisitConstructor(ConstructorProvider constructor) + protected override ConstructorProvider? VisitConstructor(ConstructorProvider constructor) { if (constructor.Signature.Parameters.Count > 0) { @@ -469,7 +469,7 @@ private class TestFilterVisitor : LibraryVisitor return property; } - protected internal override FieldProvider? VisitField(FieldProvider field) + protected override FieldProvider? VisitField(FieldProvider field) { if (field.Name == "TestField") { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs index 28981148a4d..8e8c10dca8d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs @@ -60,7 +60,42 @@ public async Task RemovesInvalidUsings() CollectionAssert.Contains(usings, "System"); } + [Test] + public async Task RemovesInvalidUsingsKeepsFileHeader() + { + MockHelpers.LoadMockGenerator(); + var workspace = new AdhocWorkspace(); + var projectInfo = ProjectInfo.Create( + ProjectId.CreateNewId(), + VersionStamp.Create(), + name: "TestProj", + assemblyName: "TestProj", + language: LanguageNames.CSharp) + .WithMetadataReferences(new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location) + }); + + var project = workspace.AddProject(projectInfo); + var folder = Helpers.GetAssetFileOrDirectoryPath(false); + project = AddGeneratedDocument( + project, + "RootClient.cs", + "src", + "Generated", + "RootClient.cs", + File.ReadAllText(Path.Join(folder, "RootClient.cs"))); + var postProcessor = new TestPostProcessor("RootClient.cs"); + + var resultProject = await postProcessor.RemoveAsync(project); + var doc = resultProject.Documents.Single(d => d.Name == "RootClient.cs"); + var text = (await doc.GetTextAsync()).ToString(); + StringAssert.StartsWith("// Copyright (c) Microsoft Corporation. All rights reserved.", text); + StringAssert.Contains("// ", text); + StringAssert.Contains("#nullable disable", text); + StringAssert.DoesNotContain("using Missing.Namespace;", text); + } [Test] public async Task DoesNotRemoveValidUsings() { @@ -289,11 +324,14 @@ public async Task DoesNotRemoveValidAttributes() Assert.AreEqual(Helpers.GetExpectedFromFile().TrimEnd(), output, "The output should match the expected content."); } + private static Project AddGeneratedDocument(Project project, string name, string folder1, string folder2, string fileName, string text) + => project.AddDocument(name, text, folders: [folder1, folder2], filePath: Path.Join(folder1, folder2, fileName)).Project; + private class TestPostProcessor : PostProcessor { private readonly string _rootFile; - public TestPostProcessor(string rootFile, IEnumerable? nonRootTypes = null) : base([], additionalNonRootTypeNames: nonRootTypes) + public TestPostProcessor(string rootFile, IEnumerable? additionalRootTypeNames = null, IEnumerable? nonRootTypes = null, string? modelFactoryFullName = null) : base((additionalRootTypeNames ?? []).ToHashSet(), modelFactoryFullName: modelFactoryFullName, additionalNonRootTypeNames: nonRootTypes) { _rootFile = rootFile; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs new file mode 100644 index 00000000000..572d6f590ac --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using Missing.Namespace; + +namespace Sample +{ + public partial class RootClient + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs index 1bdf4020167..c35f1968d76 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -177,6 +177,32 @@ public async Task OmitsModelFactoryMethodIfParamTypeInternal() Assert.IsNull(modelFactory); } + // This test validates that a derived model customized to be internal does not get a + // public model factory method just because its base model remains public. + [Test] + public async Task OmitsModelFactoryMethodIfDerivedModelTypeInternal() + { + var baseModel = InputFactory.Model( + "baseModel", + properties: [InputFactory.Property("BaseProp", InputPrimitiveType.String)]); + var derivedModel = InputFactory.Model( + "derivedModel", + properties: [InputFactory.Property("DerivedProp", InputPrimitiveType.String)], + baseModel: baseModel); + + var mockGenerator = await MockHelpers.LoadMockGeneratorAsync( + inputModelTypes: [baseModel, derivedModel], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + var modelFactory = mockGenerator.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + CollectionAssert.Contains(modelFactory!.Methods.Select(m => m.Signature.Name), "BaseModel"); + CollectionAssert.DoesNotContain(modelFactory.Methods.Select(m => m.Signature.Name), "DerivedModel"); + } + [TestCase(true)] [TestCase(false)] public async Task CanCustomizeModelFullConstructor(bool extraParameters) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs index 934239ba78f..45f9b0385b4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs @@ -975,61 +975,6 @@ public async Task BackCompatibility_BackCompatMethodAlreadyCustom() Assert.AreEqual(4, publicModel1Methods[0].Signature.Parameters.Count); } - // Back-compat members are synthesized in ProcessTypeForBackCompatibility, which runs after the - // main library visitor pass. This test ensures those newly-added members are still run through - // the registered visitors (only the new members, not the already-visited existing ones). - [Test] - public async Task BackCompatibility_BackCompatMethodIsVisited() - { - _instance = (await MockHelpers.LoadMockGeneratorAsync( - inputNamespaceName: "Sample.Namespace", - inputModelTypes: ModelList, - lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(method: "BackCompatibility_NewModelPropertyAdded"))).Object; - - var recordingVisitor = new RecordingMethodVisitor(); - _instance.AddVisitor(recordingVisitor); - - var modelFactory = _instance!.OutputLibrary.ModelFactory.Value; - modelFactory.ProcessTypeForBackCompatibility(); - - var backCompatMethod = modelFactory.Methods - .FirstOrDefault(m => m.Signature.Name == "PublicModel1" && m.Signature.Parameters.All(p => p.Name != "dictProp")); - Assert.IsNotNull(backCompatMethod, "Expected a back-compat overload to be synthesized."); - - // The synthesized back-compat method must have been visited. - Assert.IsTrue( - recordingVisitor.VisitedMethods.Contains(backCompatMethod!), - "The back-compat method was not visited by the library visitor."); - - // Existing methods that were already part of the contract are not re-visited by the visitor - // added after the main pass (they would have been visited during the main pass in a real run). - var currentOverloadMethod = modelFactory.Methods - .FirstOrDefault(m => m.Signature.Name == "PublicModel1" && m.Signature.Parameters.Any(p => p.Name == "dictProp")); - Assert.IsNotNull(currentOverloadMethod); - Assert.IsFalse(recordingVisitor.VisitedMethods.Contains(currentOverloadMethod!)); - } - - // Verifies that a visitor can mutate (rename) a synthesized back-compat method and the change is - // reflected in the final generated methods. - [Test] - public async Task BackCompatibility_BackCompatMethodCanBeMutatedByVisitor() - { - _instance = (await MockHelpers.LoadMockGeneratorAsync( - inputNamespaceName: "Sample.Namespace", - inputModelTypes: ModelList, - lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(method: "BackCompatibility_NewModelPropertyAdded"))).Object; - - _instance.AddVisitor(new BackCompatMethodRenamingVisitor()); - - var modelFactory = _instance!.OutputLibrary.ModelFactory.Value; - modelFactory.ProcessTypeForBackCompatibility(); - - // The visitor renames any method carrying the EditorBrowsableNever attribute (the back-compat - // overload) so the mutation must be observable on the final method collection. - var renamed = modelFactory.Methods.FirstOrDefault(m => m.Signature.Name == "PublicModel1Renamed"); - Assert.IsNotNull(renamed, "The visitor's rename of the back-compat method was not applied."); - } - private static InputModelType[] GetTestModels() { InputType additionalPropertiesUnknown = InputPrimitiveType.Any; @@ -1057,29 +1002,5 @@ private static InputModelType[] GetTestModels() InputFactory.Model("ModelWithUnknownAdditionalProperties", properties: properties, additionalProperties: additionalPropertiesUnknown), ]; } - - private class RecordingMethodVisitor : LibraryVisitor - { - public List VisitedMethods { get; } = []; - - protected internal override MethodProvider? VisitMethod(MethodProvider method) - { - VisitedMethods.Add(method); - return method; - } - } - - private class BackCompatMethodRenamingVisitor : LibraryVisitor - { - protected internal override MethodProvider? VisitMethod(MethodProvider method) - { - if (method.Signature.Name == "PublicModel1" - && method.Signature.Attributes.Any(a => a.ToDisplayString().Contains("EditorBrowsable"))) - { - method.Signature.Update(name: "PublicModel1Renamed"); - } - return method; - } - } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs new file mode 100644 index 00000000000..bdb2034f5f0 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Sample.Models +{ + internal partial class DerivedModel + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs index 7f3ea5cd1ff..d96e98a9d8e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs @@ -405,7 +405,7 @@ private class ClientTypeProvider : TypeProvider public MethodProvider[] MethodProviders { get; set; } = []; public ConstructorProvider[] ConstructorProviders { get; set; } = []; - protected override string BuildRelativeFilePath() => "."; + protected override string BuildRelativeFilePath() => $"{Name}.cs"; protected override string BuildName() => "MockInputClient"; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs index 9f9945a2360..a30bd15aa29 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs @@ -1524,7 +1524,7 @@ await MockHelpers.LoadMockGeneratorAsync( } [Test] - public void PublicModelsAreIncludedInAdditionalRootTypes() + public void PublicModelsAreNotIncludedInAdditionalRootTypes() { var inputModel = InputFactory.Model( "MockInputModel", @@ -1537,7 +1537,7 @@ public void PublicModelsAreIncludedInAdditionalRootTypes() Assert.IsNotNull(modelProvider); var rootTypes = CodeModelGenerator.Instance.AdditionalRootTypes; - Assert.IsTrue(rootTypes.Contains("Sample.Models.MockInputModel")); + Assert.IsFalse(rootTypes.Contains("Sample.Models.MockInputModel")); } [Test] diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs new file mode 100644 index 00000000000..27bd4cf11cd --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs @@ -0,0 +1,6 @@ +namespace Sample +{ + public class GenericContainer + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs index 3884f8d584e..e8db4359802 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs @@ -83,6 +83,19 @@ public async Task NonNullableKnownFrameworkTypeResolvesUnchanged() Assert.IsFalse(csharpType.IsNullable); } + [Test] + public async Task TypeParameterDoesNotResolveContainingGenericType() + { + var compilation = await Helpers.GetCompilationFromDirectoryAsync(); + var typeSymbol = compilation.GetTypeByMetadataName("Sample.GenericContainer`1"); + Assert.IsNotNull(typeSymbol, "Failed to resolve generic type symbol from compiled source."); + + var csharpType = typeSymbol!.TypeParameters[0].GetCSharpType(); + + Assert.AreEqual("T", csharpType.Name); + Assert.IsNull(csharpType.DeclaringType); + } + private static IPropertySymbol GetPropertySymbol(Compilation compilation, string containerName, string propertyName) { var typeSymbol = compilation.GetTypeByMetadataName($"Sample.{containerName}"); diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs index 8918ac1946c..4440dc1b6a0 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs @@ -8,7 +8,6 @@ namespace Documentation { public static partial class DocumentationModelFactory { - public static BulletPointsModel BulletPointsModel(BulletPointsEnum prop = default) => throw null; } } diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs index 06d44d34bc1..c17dda5eaec 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs @@ -3,7 +3,6 @@ #nullable disable using Parameters.Basic._ExplicitBody; -using Parameters.Basic._ImplicitBody; namespace Parameters.Basic { diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs index 775c933bc6b..4e494c2f2ec 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs @@ -2,8 +2,6 @@ #nullable disable -using System.Collections.Generic; -using Parameters.Spread._Alias; using Parameters.Spread._Model; namespace Parameters.Spread diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs index 6d036b01c86..e5c87b994a6 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs @@ -2,7 +2,6 @@ #nullable disable -using System; using System.ClientModel; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs index 0b66e798a05..1b7b8eefa6f 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs @@ -2,18 +2,12 @@ #nullable disable -using System; -using System.Collections.Generic; -using Payload.Pageable._PageSize; -using Payload.Pageable._ServerDrivenPagination; using Payload.Pageable._ServerDrivenPagination.AlternateInitialVerb; -using Payload.Pageable._ServerDrivenPagination.ContinuationToken; namespace Payload.Pageable { public static partial class PayloadPageableModelFactory { - public static Pet Pet(string id = default, string name = default) => throw null; public static XmlPet XmlPet(string id = default, string name = default) => throw null; diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs index 1abb2114c9a..43aca8db259 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs @@ -2,10 +2,8 @@ #nullable disable -using System.Collections.Generic; using SpecialWords._ModelProperties; using SpecialWords._Models; -using SpecialWords._ReservedOperationBodyParams; namespace SpecialWords { From c63241063a7731ca246b0cf8ade27c1bce5c5fc5 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 06:26:42 +0000 Subject: [PATCH 03/19] fix(csharp): preserve model union variants during pruning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/PostProcessing/ProviderReferenceMapAnalyzer.cs | 5 +++-- .../test/Providers/EnumProviders/EnumProviderTests.cs | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs index 1058ba9d66b..1754e9c1d92 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -350,7 +350,7 @@ private static HashSet GetRemovalCandidates( helperRoots: [], includeModelFactory: true, includeAdditionalRoots: true, - includeUnionVariantRoots: false, + includeUnionVariantRoots: true, publicClientRootsOnly: false); removeRoots.UnionWith(customRemovalRoots); @@ -1504,7 +1504,8 @@ private static void AddUnionVariantRoots(HashSet roots, IReadOnlyList Date: Thu, 2 Jul 2026 06:49:48 +0000 Subject: [PATCH 04/19] fix(csharp): remove analyzer infrastructure workarounds Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ProviderReferenceMapAnalyzer.cs | 139 ------------------ 1 file changed, 139 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs index 1754e9c1d92..a3c9e771efb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -1080,7 +1080,6 @@ private static void AddGeneratedBodyReferences(IReadOnlyList provi GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), graph.Nodes); AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); - AddProviderInfrastructureReferences(graph.References[providerName], provider, graph.Nodes); AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, graph.References[providerName]); } } @@ -1099,144 +1098,6 @@ private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes( return references; } - private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, HashSet nodes) - { - AddMatchingName(references, "ProviderConstants", nodes); - AddMatchingName(references, "TypeFormatters", nodes); - - if (provider.SerializationProviders.Count > 0) - { - AddSerializationExtensionReferences(references, provider, nodes); - } - - if (IsSerializationProvider(provider)) - { - AddMatchingName(references, "Optional", nodes); - AddMatchingName(references, "Utf8JsonRequestContent", nodes); - AddMatchingName(references, "ModelSerializationExtensions", nodes); - AddSerializationExtensionReferences(references, provider, nodes); - } - - foreach (var method in provider.Methods) - { - if (method.IsMethodSuppressed()) - { - continue; - } - - AddMethodInfrastructureReferences(references, method, nodes); - } - } - - private static void AddSerializationExtensionReferences(HashSet references, TypeProvider provider, HashSet nodes) - { - AddSerializationExtensionReferences(references, provider.Type, nodes); - AddSerializationExtensionReferences(references, provider.BaseType, nodes); - foreach (var implementedType in provider.Implements) - { - AddSerializationExtensionReferences(references, implementedType, nodes); - } - - foreach (var property in provider.Properties) - { - AddSerializationExtensionReferences(references, property.Type, nodes); - } - - foreach (var field in provider.Fields) - { - AddSerializationExtensionReferences(references, field.Type, nodes); - } - - foreach (var constructor in provider.Constructors) - { - AddSerializationExtensionReferences(references, constructor.Signature.ReturnType, nodes); - foreach (var parameter in constructor.Signature.Parameters) - { - AddSerializationExtensionReferences(references, parameter.Type, nodes); - } - } - - foreach (var method in provider.Methods) - { - if (method.IsMethodSuppressed()) - { - continue; - } - - AddSerializationExtensionReferences(references, method.Signature.ReturnType, nodes); - foreach (var parameter in method.Signature.Parameters) - { - AddSerializationExtensionReferences(references, parameter.Type, nodes); - } - } - } - - private static void AddSerializationExtensionReferences(HashSet references, CSharpType? type, HashSet nodes) - { - if (type == null) - { - return; - } - - AddMatchingName(references, $"{type.Name}Extensions", nodes); - foreach (var argument in type.Arguments) - { - AddSerializationExtensionReferences(references, argument, nodes); - } - } - - private static void AddMethodInfrastructureReferences(HashSet references, MethodProvider method, HashSet nodes) - { - AddReturnTypeInfrastructureReferences(references, method.Signature.ReturnType, nodes); - } - - private static void AddReturnTypeInfrastructureReferences(HashSet references, CSharpType? returnType, HashSet nodes) - { - var type = UnwrapTask(returnType); - if (type == null) - { - return; - } - - var typeName = StripGenericArity(type.Name); - if (string.Equals(typeName, "Pageable", StringComparison.Ordinal)) - { - AddMatchingName(references, "PageableWrapper", nodes); - } - else if (string.Equals(typeName, "AsyncPageable", StringComparison.Ordinal)) - { - AddMatchingName(references, "AsyncPageableWrapper", nodes); - } - else if (string.Equals(typeName, "ArmOperation", StringComparison.Ordinal)) - { - AddMatchingNamesWithSimpleNameSuffix(references, "ArmOperation", nodes); - AddMatchingNamesWithSimpleNameSuffix(references, "OperationSource", nodes); - if (type.Arguments.Count > 0) - { - AddMatchingName(references, $"{BuildOperationSourceTypeName(type.Arguments[0])}OperationSource", nodes); - } - } - } - - private static CSharpType? UnwrapTask(CSharpType? type) - { - var typeName = type == null ? null : StripGenericArity(type.Name); - if ((string.Equals(typeName, "Task", StringComparison.Ordinal) || - string.Equals(typeName, "ValueTask", StringComparison.Ordinal)) && - type?.Arguments.Count > 0) - { - return type.Arguments[0]; - } - - return type; - } - - private static string BuildOperationSourceTypeName(CSharpType type) - { - var argumentNames = string.Join("", type.Arguments.Select(BuildOperationSourceTypeName)); - return $"{type.Name}{(argumentNames.Length > 0 ? "Of" : string.Empty)}{argumentNames}"; - } - private static IReadOnlyList CollectStructuredBodyReferenceTypes(TypeProvider provider) { var references = new HashSet(); From 98fe542fcf131bd324f1ce8f3cff31e974ba5f30 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 07:56:56 +0000 Subject: [PATCH 05/19] fix(csharp): keep customized generated providers Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../Providers/CollectionResultDefinition.cs | 2 ++ ...ClientBodyDependencyPostProcessingTests.cs | 25 +++++++++++++++++++ .../ProviderReferenceMapAnalyzer.cs | 18 +++++++++---- 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index 590eaf2b935..83d3ae8444b 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -233,6 +233,8 @@ protected override IReadOnlyList BuildBodyDependencyTypes() return dependencies; } + protected override IReadOnlyList BuildHelperDependencyTypes() => [new ClientPipelineExtensionsDefinition().Type]; + protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs index e5a47dbc62e..e09ea5978c6 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -67,6 +67,31 @@ await GenerateAndAssertFiles( ]); } + [Test] + public async Task EmptyCustomPartialModelIsKept() + { + var customizedModel = InputFactory.Model("CustomizedModel"); + + await GenerateAndAssertFiles( + enums: [], + models: [customizedModel], + clients: [], + customFiles: [( + Path.Combine("src", "CustomizedModel.cs"), + """ + namespace Sample.Models + { + public partial class CustomizedModel + { + } + } + """)], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "CustomizedModel.cs"), + Path.Combine("src", "Generated", "Models", "CustomizedModel.Serialization.cs") + ]); + } + [Test] public async Task InternalAdditionalRootModelIsRemovedWhenNotOtherwiseReferenced() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs index a3c9e771efb..4e23e44f888 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -381,6 +381,7 @@ private static HashSet GetCustomCodeGeneratedTypeRoots(IReadOnlyList(StringComparer.Ordinal); foreach (var customCodeView in GetCustomCodeViews(providers)) { + AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: false); } @@ -397,11 +398,7 @@ private static HashSet GetCustomCodePublicGeneratedTypeRoots(IReadOnlyLi continue; } - if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) - { - AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); - } - + AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: true); } @@ -467,6 +464,17 @@ customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider ? namedTypeSymbolProvider.MetadataSimpleName : customCodeView.Type.Name; + private static void AddCustomCodeViewGeneratedTypeRoot(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames) + { + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + return; + } + + AddTypeReference(roots, customCodeView.Type, generatedTypeNames); + } + private static void AddCustomizationBackedExtensionRoots(HashSet roots, HashSet nodes) { foreach (var node in nodes) From 94fcdfccced4581620eee409def94a88827badc8 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 08:36:55 +0000 Subject: [PATCH 06/19] fix(csharp): keep pageable pipeline extension helper Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/CollectionResultDefinition.cs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index 83d3ae8444b..bca6dd41e4d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -219,7 +219,7 @@ protected override TypeSignatureModifiers BuildDeclarationModifiers() protected override IReadOnlyList BuildBodyDependencyTypes() { - var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType }; + var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType, new ClientPipelineExtensionsDefinition().Type }; if (ItemModelType != null) { dependencies.Add(ItemModelType); @@ -233,8 +233,6 @@ protected override IReadOnlyList BuildBodyDependencyTypes() return dependencies; } - protected override IReadOnlyList BuildHelperDependencyTypes() => [new ClientPipelineExtensionsDefinition().Type]; - protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => From d80796c27292530c0a468e19ee7b2559913ae068 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 10:03:33 +0000 Subject: [PATCH 07/19] fix(csharp): publicize generated public dependencies Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ClientBodyDependencyPostProcessingTests.cs | 17 +++++++++++++++++ .../ProviderReferenceMapAnalyzer.cs | 7 ++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs index e09ea5978c6..e00c2ec0388 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -39,6 +39,23 @@ public async Task OperationResponseBodyModelRemainsPublicAsRootOutputModel() await GenerateAndAssertPublicModels([responseModel], [client], ["ResponseBody"]); } + [Test] + public async Task InternalModelReferencedByPublicModelPropertyIsPublicized() + { + var dependencyModel = InputFactory.Model("DependencyModel", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependency", dependencyModel)]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel, dependencyModel], [client], ["ResponseBody", "DependencyModel"]); + } + [Test] public async Task OperationResponseBodyModelIsRemovedWhenNotOtherwiseReferenced() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs index 4e23e44f888..8a11636a1dd 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -174,6 +174,7 @@ private static (HashSet InternalizeCandidates, HashSet Publicize AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); var publicizeRoots = new HashSet(internalizeRoots, StringComparer.Ordinal); + var publicApiReferences = CloneReferences(publicGraph.References); var internalizeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, internalizeReachableWithoutHelpers); internalizeRoots.UnionWith(internalizeHelperRoots); var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); @@ -209,6 +210,7 @@ private static (HashSet InternalizeCandidates, HashSet Publicize publicizeRootExclusions, generatedInternalDeclarations, publicizeRoots, + publicApiReferences, internalizeReferences, generatedImplementationInternalDeclarations); @@ -305,6 +307,7 @@ private static HashSet GetPublicizeCandidates( HashSet publicizeRootExclusions, HashSet generatedInternalDeclarations, HashSet publicizeRoots, + Dictionary> publicApiReferences, Dictionary> internalizeReferences, HashSet generatedImplementationInternalDeclarations) { @@ -320,7 +323,9 @@ private static HashSet GetPublicizeCandidates( continue; } - if (generatedInternalDeclarations.Contains(node) && !publicizeRoots.Contains(node)) + if (generatedInternalDeclarations.Contains(node) && + !publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, publicApiReferences, publicizeReachable, generatedImplementationInternalDeclarations)) { continue; } From 1c1c0b731144263095f730b6cc74eeb99a1c6c80 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 11:23:53 +0000 Subject: [PATCH 08/19] Fix provider reference map edge cases Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ProviderReferenceMapAnalyzer.cs | 24 +++++++++-- .../src/Utilities/TypeSymbolExtensions.cs | 43 +++++++++++++++---- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs index 8a11636a1dd..2ce4b2b6631 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs @@ -53,7 +53,7 @@ public static void ApplyPreWriteAccessibility(IReadOnlyList provid provider.PreserveXmlDocs(); provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers)); } - else if (publicizeCandidates.Contains(providerName)) + else if (publicizeCandidates.Contains(providerName) && !IsGeneratedInternalImplementation(provider)) { provider.Update(modifiers: MakePublic(provider.DeclarationModifiers)); } @@ -180,7 +180,7 @@ private static (HashSet InternalizeCandidates, HashSet Publicize var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); var customInternalBoundaryNodes = GetCustomInternalBoundaryNodes(publicGraph, customInternalDeclarations); var publicizeDeclaredNodes = GetPublicizeDeclaredNodes(generatedProviders, graph.Nodes, internalizeDeclaredNodes); - var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedInternalDeclarations); + var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedProviders, generatedInternalDeclarations); var publicApiTraversalNodes = GetPublicApiTraversalNodes( internalizeDeclaredNodes, publicizeDeclaredNodes, @@ -510,6 +510,7 @@ private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider c AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); if (!publicOnly) { + AddProviderBodyDependencyTypes(roots, customCodeView.BodyDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); } @@ -702,9 +703,21 @@ private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessi return declarations; } - private static HashSet GetGeneratedImplementationInternalTypeDeclarations(HashSet generatedInternalDeclarations) + private static HashSet GetGeneratedImplementationInternalTypeDeclarations( + IReadOnlyList providers, + HashSet generatedInternalDeclarations) { var implementationDeclarations = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (!provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + continue; + } + + AddTypeReference(implementationDeclarations, provider.Type, generatedInternalDeclarations); + } + foreach (var name in generatedInternalDeclarations) { if (GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)) @@ -716,6 +729,11 @@ private static HashSet GetGeneratedImplementationInternalTypeDeclaration return implementationDeclarations; } + private static bool IsGeneratedInternalImplementation(TypeProvider provider) + => provider.RelativeFilePath.Contains( + $"{Path.DirectorySeparatorChar}Generated{Path.DirectorySeparatorChar}Internal{Path.DirectorySeparatorChar}", + StringComparison.Ordinal); + private static HashSet GetSimpleNames(HashSet names) { var simpleNames = new HashSet(StringComparer.Ordinal); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs index bf8cc5eee38..c7915847be6 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs @@ -43,6 +43,9 @@ public static bool IsSameType(this INamedTypeSymbol symbol, CSharpType type) } public static CSharpType GetCSharpType(this ITypeSymbol typeSymbol) + => GetCSharpType(typeSymbol, new HashSet(SymbolEqualityComparer.Default)); + + private static CSharpType GetCSharpType(this ITypeSymbol typeSymbol, HashSet visited) { var fullyQualifiedName = GetFullyQualifiedName(typeSymbol); var namedTypeSymbol = typeSymbol as INamedTypeSymbol; @@ -55,20 +58,20 @@ public static CSharpType GetCSharpType(this ITypeSymbol typeSymbol) if (namedTypeSymbol?.ConstructedFrom.SpecialType == SpecialType.System_Nullable_T && namedTypeSymbol.TypeArguments.Length == 1) { - var underlying = GetCSharpType(namedTypeSymbol.TypeArguments[0]); + var underlying = GetCSharpType(namedTypeSymbol.TypeArguments[0], visited); if (underlying.IsFrameworkType) { return underlying.WithNullable(true); } } - return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol); + return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol, visited); } CSharpType result = new CSharpType(type); if (namedTypeSymbol is not null && namedTypeSymbol.IsGenericType && !result.IsNullable) { - return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)]); + return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(t => GetCSharpType(t, visited))]); } return result; @@ -175,8 +178,14 @@ public static string GetFullyQualifiedNameFromDisplayString(this ISymbol typeSym private static CSharpType ConstructCSharpTypeFromSymbol( ITypeSymbol typeSymbol, string fullyQualifiedName, - INamedTypeSymbol? namedTypeSymbol) + INamedTypeSymbol? namedTypeSymbol, + HashSet visited) { + if (!visited.Add(typeSymbol)) + { + return ConstructShallowCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName); + } + var typeArg = namedTypeSymbol?.TypeArguments.FirstOrDefault(); bool isValueType = typeSymbol.IsValueType; bool isNullable = fullyQualifiedName.StartsWith(NullableTypeName); @@ -193,7 +202,7 @@ private static CSharpType ConstructCSharpTypeFromSymbol( if (namedTypeSymbol?.IsGenericType == true && (!isNullable || (namedTypeArg?.IsGenericType == true))) { - arguments.AddRange(namedTypeSymbol.TypeArguments.Select(GetCSharpType)); + arguments.AddRange(namedTypeSymbol.TypeArguments.Select(t => GetCSharpType(t, visited))); } // handle nullables @@ -209,7 +218,7 @@ private static CSharpType ConstructCSharpTypeFromSymbol( if (typeSymbol.ContainingType != null && typeSymbol.TypeKind != TypeKind.TypeParameter) { - containingType = GetCSharpType(typeSymbol.ContainingType); + containingType = GetCSharpType(typeSymbol.ContainingType, visited); ns = string.Join('.', pieces.Take(pieces.Length - 2)); } @@ -219,10 +228,10 @@ private static CSharpType ConstructCSharpTypeFromSymbol( !isNullableUnknownType && !ContainsTypeAsArgument(typeSymbol.BaseType, typeSymbol)) { - baseType = GetCSharpType(typeSymbol.BaseType); + baseType = GetCSharpType(typeSymbol.BaseType, visited); } - return new CSharpType( + var result = new CSharpType( name, ns, isValueType, @@ -233,8 +242,24 @@ private static CSharpType ConstructCSharpTypeFromSymbol( isValueType && !isEnum, baseType: baseType, underlyingEnumType: enumUnderlyingType != null - ? GetCSharpType(enumUnderlyingType).FrameworkType + ? GetCSharpType(enumUnderlyingType, visited).FrameworkType : null); + visited.Remove(typeSymbol); + return result; + } + + private static CSharpType ConstructShallowCSharpTypeFromSymbol(ITypeSymbol typeSymbol, string fullyQualifiedName) + { + string[] pieces = fullyQualifiedName.Split('`')[0].Split('.'); + return new CSharpType( + typeSymbol.Name, + string.Join('.', pieces.Take(pieces.Length - 1)), + typeSymbol.IsValueType, + fullyQualifiedName.StartsWith(NullableTypeName), + null, + [], + typeSymbol.DeclaredAccessibility == Accessibility.Public, + typeSymbol.IsValueType && typeSymbol.TypeKind != TypeKind.Enum); } internal static bool ContainsTypeAsArgument(ITypeSymbol potentialGenericType, ITypeSymbol targetType) From a264f2ee4e595d688aafd06e2d46a501d1ca2e50 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 14:31:28 +0000 Subject: [PATCH 09/19] fix(http-client-csharp): improve provider reference map analysis Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ...ybrid-reference-map-2026-07-02-14-20-08.md | 7 + ...ClientBodyDependencyPostProcessingTests.cs | 105 + .../ProviderReferenceMapAnalyzer.cs | 2093 ----------------- .../Primitives/PropertyDescriptionBuilder.cs | 2 +- ...iderReferenceMapAnalyzer.BodyReferences.cs | 254 ++ ...ProviderReferenceMapAnalyzer.Candidates.cs | 199 ++ ...derReferenceMapAnalyzer.CustomCodeRoots.cs | 362 +++ ...oviderReferenceMapAnalyzer.GraphBuilder.cs | 175 ++ .../ProviderReferenceMapAnalyzer.Helpers.cs | 484 ++++ ...ReferenceMapAnalyzer.ReferenceTraversal.cs | 188 ++ ...viderReferenceMapAnalyzer.RootSelection.cs | 164 ++ ...renceMapAnalyzer.TypeReferenceCollector.cs | 172 ++ .../ProviderReferenceMapAnalyzer.cs | 242 ++ 13 files changed, 2353 insertions(+), 2094 deletions(-) create mode 100644 .chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md delete mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs diff --git a/.chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md b/.chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md new file mode 100644 index 00000000000..533726f69bf --- /dev/null +++ b/.chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md @@ -0,0 +1,7 @@ +--- +changeKind: fix +packages: + - "@typespec/http-client-csharp" +--- + +Improve generated C# reference-map analysis so provider accessibility and XML documentation stay consistent. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs index e00c2ec0388..2cadcbe5cb0 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.TypeSpec.Generator.Input; +using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Tests.Common; using NUnit.Framework; @@ -56,6 +57,110 @@ public async Task InternalModelReferencedByPublicModelPropertyIsPublicized() await GenerateAndAssertPublicModels([responseModel, dependencyModel], [client], ["ResponseBody", "DependencyModel"]); } + [Test] + public async Task InternalModelReferencedByPublicNonRootCollectionPropertyIsPublicized() + { + var dependencyModel = InputFactory.Model("DependencyModel", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependencies", InputFactory.Array(dependencyModel))]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [responseModel, dependencyModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Generated", "SampleModelFactory.cs"), """ + using System.Collections.Generic; + using Sample.Models; + + namespace Sample; + + public static partial class SampleModelFactory + { + public static ResponseBody ResponseBody(IEnumerable dependencies = default) => null; + } + """) + ], + expectedFiles: [], + publicModelNames: ["ResponseBody", "DependencyModel"], + configureGenerator: () => + { + var responseProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "ResponseBody"); + var dependencyProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "DependencyModel"); + CodeModelGenerator.Instance.AddTypeToKeep(responseProvider, isRoot: false); + CodeModelGenerator.Instance.AddTypeToKeep(dependencyProvider, isRoot: false); + }); + } + + [Test] + public async Task CustomInternalBoundaryInternalizesPublicNonRootModel() + { + var customInternalModel = InputFactory.Model("CustomInternalModel"); + var publicWrapper = InputFactory.Model( + "PublicWrapper", + properties: [InputFactory.Property("CustomInternal", customInternalModel)]); + + var outputPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Directory.CreateDirectory(outputPath); + try + { + var customPath = Path.Combine(outputPath, "src", "Custom", "CustomInternalModel.cs"); + Directory.CreateDirectory(Path.GetDirectoryName(customPath)!); + File.WriteAllText(customPath, """ + using Microsoft.TypeSpec.Generator.Customizations; + + namespace Sample.Models; + + [CodeGenType("CustomInternalModel")] + internal partial class CustomInternalModel + { + } + """); + var modelFactoryPath = Path.Combine(outputPath, "src", "Generated", "SampleModelFactory.cs"); + Directory.CreateDirectory(Path.GetDirectoryName(modelFactoryPath)!); + File.WriteAllText(modelFactoryPath, """ + using Sample.Models; + + namespace Sample; + + public static partial class SampleModelFactory + { + public static PublicWrapper PublicWrapper(CustomInternalModel customInternal = default) => null; + } + """); + + await MockHelpers.LoadMockGeneratorAsync( + inputModels: () => [publicWrapper, customInternalModel], + configuration: """{ "package-name": "Sample", "disable-xml-docs": true }""", + outputPath: outputPath); + + var publicWrapperProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "PublicWrapper"); + var customInternalProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "CustomInternalModel"); + CodeModelGenerator.Instance.AddTypeToKeep(publicWrapperProvider, isRoot: false); + CodeModelGenerator.Instance.AddTypeToKeep(customInternalProvider, isRoot: false); + + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility(CodeModelGenerator.Instance.OutputLibrary.TypeProviders); + + Assert.IsTrue(publicWrapperProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal), "PublicWrapper should be internalized when its public surface exposes a custom/internal type."); + Assert.IsTrue(customInternalProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal), "CustomInternalModel should remain internal."); + } + finally + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + if (Directory.Exists(outputPath)) + { + Directory.Delete(outputPath, recursive: true); + } + } + } + [Test] public async Task OperationResponseBodyModelIsRemovedWhenNotOtherwiseReferenced() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs deleted file mode 100644 index 8a11636a1dd..00000000000 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapAnalyzer.cs +++ /dev/null @@ -1,2093 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text.RegularExpressions; -using Microsoft.TypeSpec.Generator.Expressions; -using Microsoft.TypeSpec.Generator.Primitives; -using Microsoft.TypeSpec.Generator.Providers; -using Microsoft.TypeSpec.Generator.Statements; - -namespace Microsoft.TypeSpec.Generator -{ - internal static class ProviderReferenceMapAnalyzer - { - private static ProviderReferenceMapResult? _latestResult; - private static readonly ConditionalWeakTable, Dictionary> _simpleNameLookupCache = new(); - private static TypeProvider? _preWriteModelFactory; - private static MethodProvider[]? _preWriteModelFactoryMethods; - - public static ProviderReferenceMapResult? LatestResult => _latestResult; - public static bool PreWriteAccessibilityApplied { get; private set; } - - public static bool ShouldWriteProvider(TypeProvider provider) => - _latestResult?.RemoveCandidates.Contains(GetProviderTypeName(provider.Type)) != true; - - public static void ResetPreWriteAccessibility() - { - RestorePreWriteModelFactoryMethods(); - _latestResult = null; - PreWriteAccessibilityApplied = false; - } - - public static void ApplyPreWriteAccessibility(IReadOnlyList providers) - { - PreWriteAccessibilityApplied = false; - if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.KeepAll) - { - return; - } - - var (internalizeCandidates, publicizeCandidates) = GetPreWriteAccessibilityCandidates(providers); - foreach (var provider in GetGeneratedProviders(providers)) - { - var providerName = GetProviderTypeName(provider.Type); - if (internalizeCandidates.Contains(providerName)) - { - provider.PreserveXmlDocs(); - provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers)); - } - else if (publicizeCandidates.Contains(providerName)) - { - provider.Update(modifiers: MakePublic(provider.DeclarationModifiers)); - } - } - - RemoveMethodsFromModelFactory(GetSimpleNames(internalizeCandidates)); - PreWriteAccessibilityApplied = true; - } - - public static void RestorePreWriteModelFactoryMethods() - { - if (_preWriteModelFactory == null || _preWriteModelFactoryMethods == null) - { - return; - } - - _preWriteModelFactory.Update(methods: _preWriteModelFactoryMethods); - _preWriteModelFactory = null; - _preWriteModelFactoryMethods = null; - } - - public static void Analyze(IReadOnlyList providers) - { - var generatedProviders = GetGeneratedProviders(providers); - var graph = BuildGraph(generatedProviders); - var publicGraph = BuildGraph(generatedProviders, publicOnly: true); - - var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); - var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); - customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); - var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); - customPublicRoots.UnionWith(generatedPublicDeclarations); - var customCodeRemovalRoots = GetCustomCodeGeneratedTypeRoots(generatedProviders, graph.Nodes); - var customRemovalRoots = new HashSet(customCodeRemovalRoots, StringComparer.Ordinal); - customRemovalRoots.UnionWith(apiBaselineGeneratedTypeRoots); - customRemovalRoots.UnionWith(generatedPublicDeclarations); - var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); - var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); - - // Helper types are rooted after an initial reachability pass so unused infrastructure - // such as change-tracking dictionaries can still be removed when no reachable type needs them. - var generatedDiscriminatorBaseNames = GetGeneratedPersistableModelProxyTypeNames(generatedProviders, publicGraph.Nodes); - var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( - providers, - generatedProviders, - graph, - publicGraph, - customPublicRoots, - customInternalDeclarations, - generatedInternalDeclarations, - generatedDiscriminatorBaseNames); - - // Body-only generated dependencies are needed to avoid deleting helper files, but they do - // not contribute to public API reachability for internalization. - AddGeneratedBodyReferences(providers, graph); - var removeCandidates = GetRemovalCandidates( - providers, - generatedProviders, - graph, - customRemovalRoots, - generatedDiscriminatorBaseNames); - - _latestResult = new ProviderReferenceMapResult( - internalizeCandidates, - publicizeCandidates, - removeCandidates); - RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); - } - - private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates) GetPreWriteAccessibilityCandidates(IReadOnlyList providers) - { - var generatedProviders = GetGeneratedProviders(providers); - var graph = BuildGraph(generatedProviders); - var publicGraph = BuildGraph(generatedProviders, publicOnly: true); - var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); - var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); - customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); - var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); - customPublicRoots.UnionWith(generatedPublicDeclarations); - var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); - var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); - var generatedDiscriminatorBaseNames = new HashSet(StringComparer.Ordinal); - - var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( - providers, - generatedProviders, - graph, - publicGraph, - customPublicRoots, - customInternalDeclarations, - generatedInternalDeclarations, - generatedDiscriminatorBaseNames); - - return (internalizeCandidates, publicizeCandidates); - } - - private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates, HashSet InternalizeHelperRoots) GetAccessibilityCandidates( - IReadOnlyList providers, - IReadOnlyList generatedProviders, - ProviderReferenceGraph graph, - ProviderReferenceGraph publicGraph, - HashSet customPublicRoots, - HashSet customInternalDeclarations, - HashSet generatedInternalDeclarations, - HashSet generatedDiscriminatorBaseNames) - { - var internalizeReferences = CloneReferences(publicGraph.References); - var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false, includeAdditionalRoots: true, includeUnionVariantRoots: false, publicClientRootsOnly: true); - if (ShouldUseUnionVariantFallbackRoots()) - { - AddUnionVariantRoots(internalizeRoots, providers, graph.Nodes); - } - - var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); - AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable, generatedDiscriminatorBaseNames); - internalizeRoots.UnionWith(customPublicRoots); - var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); - AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); - internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); - var publicizeRoots = new HashSet(internalizeRoots, StringComparer.Ordinal); - var publicApiReferences = CloneReferences(publicGraph.References); - var internalizeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, internalizeReachableWithoutHelpers); - internalizeRoots.UnionWith(internalizeHelperRoots); - var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); - var customInternalBoundaryNodes = GetCustomInternalBoundaryNodes(publicGraph, customInternalDeclarations); - var publicizeDeclaredNodes = GetPublicizeDeclaredNodes(generatedProviders, graph.Nodes, internalizeDeclaredNodes); - var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedInternalDeclarations); - var publicApiTraversalNodes = GetPublicApiTraversalNodes( - internalizeDeclaredNodes, - publicizeDeclaredNodes, - generatedInternalDeclarations, - generatedImplementationInternalDeclarations); - var publicizeReachable = GetReachableTypes(publicizeRoots, internalizeReferences, publicApiTraversalNodes); - var internalizeCandidates = GetInternalizeCandidates( - internalizeDeclaredNodes, - publicizeReachable, - customInternalDeclarations, - customInternalBoundaryNodes, - publicizeRoots); - var publicizeRootExclusions = GetRootNames( - providers, - graph.Nodes, - helperRoots: [], - includeModelFactory: true, - includeAdditionalRoots: true, - includeUnionVariantRoots: true, - publicClientRootsOnly: true); - var publicizeCandidates = GetPublicizeCandidates( - publicizeDeclaredNodes, - publicizeReachable, - customInternalDeclarations, - customInternalBoundaryNodes, - internalizeHelperRoots, - publicizeRootExclusions, - generatedInternalDeclarations, - publicizeRoots, - publicApiReferences, - internalizeReferences, - generatedImplementationInternalDeclarations); - - return (internalizeCandidates, publicizeCandidates, internalizeHelperRoots); - } - - private static HashSet GetCustomInternalBoundaryNodes( - ProviderReferenceGraph publicGraph, - HashSet customInternalDeclarations) - { - var boundaryNodes = new HashSet(StringComparer.Ordinal); - foreach (var node in publicGraph.Nodes) - { - if (!publicGraph.References.TryGetValue(node, out var references)) - { - continue; - } - - if (references.Overlaps(customInternalDeclarations)) - { - boundaryNodes.Add(node); - } - } - - return boundaryNodes; - } - - private static HashSet GetPublicizeDeclaredNodes( - IReadOnlyList generatedProviders, - HashSet nodes, - HashSet internalizeDeclaredNodes) - { - var publicizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, nodes, publicOnly: false); - publicizeDeclaredNodes.ExceptWith(internalizeDeclaredNodes); - return publicizeDeclaredNodes; - } - - private static HashSet GetPublicApiTraversalNodes( - HashSet internalizeDeclaredNodes, - HashSet publicizeDeclaredNodes, - HashSet generatedInternalDeclarations, - HashSet generatedImplementationInternalDeclarations) - { - var traversalNodes = new HashSet(StringComparer.Ordinal); - foreach (var node in internalizeDeclaredNodes) - { - if (generatedInternalDeclarations.Contains(node) || - generatedImplementationInternalDeclarations.Contains(node)) - { - continue; - } - - traversalNodes.Add(node); - } - - foreach (var node in publicizeDeclaredNodes) - { - if (!generatedImplementationInternalDeclarations.Contains(node)) - { - traversalNodes.Add(node); - } - } - - return traversalNodes; - } - - private static HashSet GetInternalizeCandidates( - HashSet internalizeDeclaredNodes, - HashSet publicizeReachable, - HashSet customInternalDeclarations, - HashSet customInternalBoundaryNodes, - HashSet publicizeRoots) - { - var candidates = new HashSet(StringComparer.Ordinal); - foreach (var node in internalizeDeclaredNodes) - { - if (!publicizeReachable.Contains(node) || - customInternalDeclarations.Contains(node) || - customInternalBoundaryNodes.Contains(node) && !publicizeRoots.Contains(node)) - { - candidates.Add(node); - } - } - - return candidates; - } - - private static HashSet GetPublicizeCandidates( - HashSet publicizeDeclaredNodes, - HashSet publicizeReachable, - HashSet customInternalDeclarations, - HashSet customInternalBoundaryNodes, - HashSet internalizeHelperRoots, - HashSet publicizeRootExclusions, - HashSet generatedInternalDeclarations, - HashSet publicizeRoots, - Dictionary> publicApiReferences, - Dictionary> internalizeReferences, - HashSet generatedImplementationInternalDeclarations) - { - var candidates = new HashSet(StringComparer.Ordinal); - foreach (var node in publicizeDeclaredNodes) - { - if (customInternalDeclarations.Contains(node) || - customInternalBoundaryNodes.Contains(node) || - internalizeHelperRoots.Contains(node) || - publicizeRootExclusions.Contains(node) || - !publicizeReachable.Contains(node)) - { - continue; - } - - if (generatedInternalDeclarations.Contains(node) && - !publicizeRoots.Contains(node) && - !HasPublicApiPredecessor(node, publicApiReferences, publicizeReachable, generatedImplementationInternalDeclarations)) - { - continue; - } - - if (!publicizeRoots.Contains(node) && - !HasPublicApiPredecessor(node, internalizeReferences, publicizeReachable, generatedImplementationInternalDeclarations)) - { - continue; - } - - candidates.Add(node); - } - - return candidates; - } - - private static HashSet GetRemovalCandidates( - IReadOnlyList providers, - IReadOnlyList generatedProviders, - ProviderReferenceGraph graph, - HashSet customRemovalRoots, - HashSet generatedDiscriminatorBaseNames) - { - var removeRoots = GetRootNames( - providers, - graph.Nodes, - helperRoots: [], - includeModelFactory: true, - includeAdditionalRoots: true, - includeUnionVariantRoots: true, - publicClientRootsOnly: false); - - removeRoots.UnionWith(customRemovalRoots); - AddMatchingNamesWithSimpleNameSuffix(removeRoots, "ReferenceType", graph.Nodes); - AddCustomCodeExtensionRoots(removeRoots, generatedProviders, graph.Nodes); - AddCustomizationBackedExtensionRoots(removeRoots, graph.Nodes); - AddCustomRequestHeaderExtensionsRoot(removeRoots, generatedProviders, graph.Nodes); - RemoveUnusedRequestHeaderExtensionsRoot(removeRoots, graph.References, providers); - - var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); - AddDerivedModelReferences(providers, graph.Nodes, graph.References, removeReachableWithoutHelpers, generatedDiscriminatorBaseNames); - removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); - AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachableWithoutHelpers); - - var removeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, removeReachableWithoutHelpers, graph.References); - removeRoots.UnionWith(removeHelperRoots); - - var removeReachable = GetReachableTypes(removeRoots, graph.References); - AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachable); - - var removeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: false); - removeDeclaredNodes.ExceptWith(removeReachable); - return removeDeclaredNodes; - } - - private static HashSet GetCustomCodeGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) - { - var roots = new HashSet(StringComparer.Ordinal); - foreach (var customCodeView in GetCustomCodeViews(providers)) - { - AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); - AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: false); - } - - return roots; - } - - private static HashSet GetCustomCodePublicGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) - { - var roots = new HashSet(StringComparer.Ordinal); - foreach (var customCodeView in GetCustomCodeViews(providers)) - { - if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) - { - continue; - } - - AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); - AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: true); - } - - return roots; - } - - private static IEnumerable GetCustomCodeViews(IReadOnlyList providers) - { - var visited = new HashSet(StringComparer.Ordinal); - var modelFactoryCustomCodeView = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value.CustomCodeView; - if (modelFactoryCustomCodeView != null && visited.Add(GetCustomCodeViewIdentity(modelFactoryCustomCodeView))) - { - yield return modelFactoryCustomCodeView; - } - - foreach (var provider in providers) - { - var customCodeView = provider.CustomCodeView; - if (customCodeView == null || !visited.Add(GetCustomCodeViewIdentity(customCodeView))) - { - continue; - } - - yield return customCodeView; - } - - foreach (var customTypeProvider in CodeModelGenerator.Instance.SourceInputModel.GetCustomizationTypeProviders()) - { - if (visited.Add(GetCustomCodeViewIdentity(customTypeProvider))) - { - yield return customTypeProvider; - } - } - } - - private static string GetCustomCodeViewIdentity(TypeProvider customCodeView) => - customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider - ? namedTypeSymbolProvider.MetadataName - : GetProviderTypeName(customCodeView.Type); - - private static void AddCustomRequestHeaderExtensionsRoot(HashSet roots, IReadOnlyList providers, HashSet nodes) - { - // TODO: Resolve body-level SetDelimited extension calls to PipelineRequestHeadersExtensions so this can be a normal type edge. - if (!HasCustomRequestHeaderExtensionsReference(providers)) - { - return; - } - - AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeaderExtensions", nodes); - AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeadersExtensions", nodes); - } - - private static void AddCustomCodeExtensionRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) - { - foreach (var customCodeView in GetCustomCodeViews(providers)) - { - AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", nodes); - } - } - - private static string GetCustomCodeViewSimpleName(TypeProvider customCodeView) => - customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider - ? namedTypeSymbolProvider.MetadataSimpleName - : customCodeView.Type.Name; - - private static void AddCustomCodeViewGeneratedTypeRoot(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames) - { - if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) - { - AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); - return; - } - - AddTypeReference(roots, customCodeView.Type, generatedTypeNames); - } - - private static void AddCustomizationBackedExtensionRoots(HashSet roots, HashSet nodes) - { - foreach (var node in nodes) - { - var simpleName = GetSimpleName(node); - if (!simpleName.EndsWith("Extensions", StringComparison.Ordinal)) - { - continue; - } - - var namespaceName = GetNamespaceName(node); - if (namespaceName == null) - { - continue; - } - - var customTypeName = simpleName.Substring(0, simpleName.Length - "Extensions".Length); - if (CodeModelGenerator.Instance.SourceInputModel.FindForTypeInCustomization(namespaceName, customTypeName) != null) - { - roots.Add(node); - } - } - } - - private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames, bool publicOnly) - { - AddTypeReference(roots, customCodeView.BaseType, generatedTypeNames); - AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); - if (!publicOnly) - { - AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); - AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); - } - - foreach (var implementedType in customCodeView.Implements) - { - AddTypeReference(roots, implementedType, generatedTypeNames); - } - - foreach (var constructor in customCodeView.Constructors) - { - if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) - { - continue; - } - - AddSignatureReferences(roots, constructor.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); - } - - foreach (var method in customCodeView.Methods) - { - if (publicOnly && !IsPublic(method.Signature.Modifiers)) - { - continue; - } - - AddSignatureReferences(roots, method.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); - } - - foreach (var property in customCodeView.Properties) - { - if (publicOnly && !IsPublic(property.Modifiers)) - { - continue; - } - - AddTypeReference(roots, property.Type, generatedTypeNames); - AddTypeReference(roots, property.ExplicitInterface, generatedTypeNames); - if (!publicOnly) - { - AddAttributes(roots, property.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); - } - } - - foreach (var field in customCodeView.Fields) - { - if (publicOnly && !IsPublic(field.Modifiers)) - { - continue; - } - - AddTypeReference(roots, field.Type, generatedTypeNames); - if (!publicOnly) - { - AddAttributes(roots, field.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); - } - } - } - - private static HashSet GetApiBaselineGeneratedTypeRoots(HashSet generatedTypeNames) - { - var roots = new HashSet(StringComparer.Ordinal); - var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; - if (string.IsNullOrEmpty(projectDirectory)) - { - return roots; - } - - var apiDirectory = Path.GetFullPath(Path.Combine(projectDirectory, "..", "api")); - if (!Directory.Exists(apiDirectory)) - { - return roots; - } - - var apiText = string.Join("\n", Directory.GetFiles(apiDirectory, "*.cs", SearchOption.AllDirectories).Select(File.ReadAllText)); - var apiDeclaredTypeNames = GetApiDeclaredTypeNames(apiText); - foreach (var fullName in generatedTypeNames) - { - var simpleName = StripGenericArity(GetSimpleName(fullName)); - var normalizedFullName = StripGenericArity(fullName); - if (!ContainsApiTypeReference(apiText, apiDeclaredTypeNames, normalizedFullName, simpleName)) - { - continue; - } - - roots.Add(fullName); - } - - return roots; - } - - private static HashSet GetApiDeclaredTypeNames(string apiText) - { - var declaredTypeNames = new HashSet(StringComparer.Ordinal); - string? currentNamespace = null; - foreach (var line in apiText.Split('\n')) - { - var namespaceMatch = Regex.Match(line, @"^namespace\s+([\w.]+)\s*\{?\s*$"); - if (namespaceMatch.Success) - { - currentNamespace = namespaceMatch.Groups[1].Value; - continue; - } - - if (currentNamespace == null) - { - continue; - } - - var declarationMatch = Regex.Match(line, @"^ \S.*?\b(class|struct|interface|enum)\s+([A-Za-z_][A-Za-z0-9_]*)(?!\s*<)(?!\w)"); - if (declarationMatch.Success) - { - declaredTypeNames.Add($"{currentNamespace}.{declarationMatch.Groups[2].Value}"); - } - } - - return declaredTypeNames; - } - - private static bool ContainsApiTypeReference(string apiText, HashSet apiDeclaredTypeNames, string fullName, string simpleName) - { - var fullNamePattern = $@"(? GetCustomCodeInternalGeneratedTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) - { - var declarations = new HashSet(StringComparer.Ordinal); - foreach (var customCodeView in GetCustomCodeViews(providers)) - { - if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)) - { - continue; - } - - if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) - { - AddMatchingName(declarations, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); - } - else - { - AddTypeReference(declarations, customCodeView.Type, generatedTypeNames); - } - } - - return declarations; - } - - private static HashSet GetGeneratedPersistableModelProxyTypeNames(IReadOnlyList providers, HashSet generatedTypeNames) - { - var proxyTypes = new HashSet(StringComparer.Ordinal); - foreach (var provider in GetGeneratedProviders(providers)) - { - if (provider.Attributes.Any(static attribute => IsAttributeNamed(attribute, "PersistableModelProxy"))) - { - AddTypeReference(proxyTypes, provider.Type, generatedTypeNames); - } - } - - return proxyTypes; - } - - private static HashSet GetGeneratedInternalTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) - => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Internal); - - private static HashSet GetGeneratedPublicTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) - => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Public); - - private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessibility( - IReadOnlyList providers, - HashSet generatedTypeNames, - TypeSignatureModifiers accessibility) - { - var declarations = new HashSet(StringComparer.Ordinal); - foreach (var provider in GetGeneratedProviders(providers)) - { - if (provider.LastContractView?.DeclarationModifiers.HasFlag(accessibility) != true) - { - continue; - } - - AddTypeReference(declarations, provider.Type, generatedTypeNames); - } - - return declarations; - } - - private static HashSet GetGeneratedImplementationInternalTypeDeclarations(HashSet generatedInternalDeclarations) - { - var implementationDeclarations = new HashSet(StringComparer.Ordinal); - foreach (var name in generatedInternalDeclarations) - { - if (GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)) - { - implementationDeclarations.Add(name); - } - } - - return implementationDeclarations; - } - - private static HashSet GetSimpleNames(HashSet names) - { - var simpleNames = new HashSet(StringComparer.Ordinal); - foreach (var name in names) - { - simpleNames.Add(GetSimpleName(name)); - } - - return simpleNames; - } - - private static ProviderReferenceGraph BuildGraph(IReadOnlyList generatedProviders, bool publicOnly = false) - { - // Each generated provider becomes a node, and provider metadata supplies the edges: - // inheritance, signatures, properties, fields, nested/serialization providers, attributes, - // and selected implementation dependencies. This avoids parsing generated C# just to - // rediscover generated-to-generated references. - var serializationProviderNamesByType = GetSerializationProviderNamesByType(generatedProviders); - IReadOnlyDictionary? serializationReferenceNamesByType = publicOnly ? null : serializationProviderNamesByType; - var nodes = new HashSet(StringComparer.Ordinal); - var references = new Dictionary>(StringComparer.Ordinal); - foreach (var provider in generatedProviders) - { - var providerName = GetProviderTypeName(provider.Type); - if (nodes.Add(providerName)) - { - references.Add(providerName, new HashSet(StringComparer.Ordinal)); - } - } - - foreach (var provider in generatedProviders) - { - var current = GetProviderTypeName(provider.Type); - AddTypeReference(references[current], provider.Type, nodes, serializationReferenceNamesByType); - AddTypeReference(references[current], provider.BaseType, nodes, serializationReferenceNamesByType); - AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationReferenceNamesByType); - - if (IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) - { - continue; - } - - // Model factory signatures mention many models. The existing Roslyn post-processor - // removes factory methods for unreachable models, so model factory should only - // contribute helper dependencies, not model reachability edges. - if (IsModelFactoryProvider(provider)) - { - continue; - } - - foreach (var implementedType in provider.Implements) - { - AddTypeReference(references[current], implementedType, nodes, serializationReferenceNamesByType); - } - - if (!publicOnly) - { - foreach (var nestedType in provider.NestedTypes) - { - AddTypeReference(references[current], nestedType.Type, nodes, serializationReferenceNamesByType); - } - } - - if (!publicOnly) - { - foreach (var serializationProvider in provider.SerializationProviders) - { - AddTypeReference(references[current], serializationProvider.Type, nodes, serializationReferenceNamesByType); - } - } - - foreach (var property in provider.Properties) - { - if (publicOnly && !IsPublic(property.Modifiers)) - { - continue; - } - - AddTypeReference(references[current], property.Type, nodes, serializationReferenceNamesByType); - AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationReferenceNamesByType); - if (!publicOnly) - { - AddAttributes(references[current], property.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); - } - } - - foreach (var field in provider.Fields) - { - if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) - { - continue; - } - - AddTypeReference(references[current], field.Type, nodes, serializationReferenceNamesByType); - if (!publicOnly) - { - AddAttributes(references[current], field.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); - } - } - - foreach (var constructor in provider.Constructors) - { - if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) - { - continue; - } - - AddSignatureReferences(references[current], constructor.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); - } - - foreach (var method in provider.Methods) - { - if (method.IsMethodSuppressed()) - { - continue; - } - - if (publicOnly && !IsPublic(method.Signature.Modifiers)) - { - continue; - } - - AddSignatureReferences(references[current], method.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); - if (!publicOnly) - { - AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationReferenceNamesByType); - } - } - } - - return new ProviderReferenceGraph(nodes, references); - } - - private static Dictionary GetSerializationProviderNamesByType(IReadOnlyList generatedProviders) - { - var namesByType = new Dictionary>(StringComparer.Ordinal); - foreach (var provider in generatedProviders) - { - if (provider.SerializationProviders.Count == 0) - { - continue; - } - - var providerName = GetProviderTypeName(provider.Type); - if (!namesByType.TryGetValue(providerName, out var serializationProviderNames)) - { - serializationProviderNames = new HashSet(StringComparer.Ordinal); - namesByType.Add(providerName, serializationProviderNames); - } - - foreach (var serializationProvider in provider.SerializationProviders) - { - serializationProviderNames.Add(GetProviderTypeName(serializationProvider.Type)); - } - } - - var result = new Dictionary(StringComparer.Ordinal); - foreach (var (providerName, serializationProviderNames) in namesByType) - { - result.Add(providerName, [.. serializationProviderNames]); - } - - return result; - } - - private static CSharpType? GetCollectionDefinitionType(MethodProvider method) - { - var property = method.GetType().GetProperty("CollectionDefinition"); - return property?.GetValue(method) is TypeProvider collectionDefinition - ? collectionDefinition.Type - : null; - } - - private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); - private static bool IsPublic(FieldModifiers modifiers) => modifiers.HasFlag(FieldModifiers.Public); - - private static TypeSignatureModifiers MakeInternal(TypeSignatureModifiers modifiers) - => (modifiers & ~(TypeSignatureModifiers.Public | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Internal; - - private static TypeSignatureModifiers MakePublic(TypeSignatureModifiers modifiers) - => (modifiers & ~(TypeSignatureModifiers.Internal | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Public; - - private static Dictionary> CloneReferences(IReadOnlyDictionary> references) - { - var clone = new Dictionary>(StringComparer.Ordinal); - foreach (var (name, referencedNames) in references) - { - clone.Add(name, new HashSet(referencedNames, StringComparer.Ordinal)); - } - - return clone; - } - - private static void AddDerivedModelReferences( - IReadOnlyList providers, - HashSet nodes, - Dictionary> references, - HashSet publicBaseModels, - HashSet generatedDiscriminatorBaseNames) - { - var modelProviders = new List(); - var discriminatorProviders = new List(); - var discriminatorBaseNames = new HashSet(StringComparer.Ordinal); - foreach (var provider in providers) - { - if (provider is not ModelProvider modelProvider || - !modelProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) - { - continue; - } - - modelProviders.Add(modelProvider); - - if (modelProvider.DiscriminatorProperty != null) - { - discriminatorBaseNames.Add(GetProviderTypeName(modelProvider.Type)); - } - - if (!modelProvider.IsUnknownDiscriminatorModel && - (modelProvider.DiscriminatorProperty != null || modelProvider.DiscriminatorValue != null)) - { - discriminatorProviders.Add(modelProvider); - } - } - - discriminatorBaseNames.UnionWith(generatedDiscriminatorBaseNames); - var addedReference = true; - while (addedReference) - { - addedReference = false; - foreach (var provider in discriminatorProviders) - { - var providerName = GetProviderTypeName(provider.Type); - if (!nodes.Contains(providerName)) - { - continue; - } - - if (!publicBaseModels.Contains(providerName)) - { - continue; - } - - foreach (var derivedModel in provider.DerivedModels) - { - if (derivedModel.IsUnknownDiscriminatorModel || - !derivedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) - { - continue; - } - - var before = references[providerName].Count; - AddTypeReference(references[providerName], derivedModel.Type, nodes); - var derivedName = GetProviderTypeName(derivedModel.Type); - if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) - { - addedReference = true; - } - } - } - - foreach (var provider in modelProviders) - { - if (provider.IsUnknownDiscriminatorModel || - !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) - { - continue; - } - - var providerName = GetProviderTypeName(provider.Type); - if (!nodes.Contains(providerName)) - { - continue; - } - - var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); - if (baseTypeName == null || - !discriminatorBaseNames.Contains(baseTypeName) || - !nodes.Contains(baseTypeName) || - !publicBaseModels.Contains(baseTypeName)) - { - continue; - } - - var before = references[baseTypeName].Count; - references[baseTypeName].Add(providerName); - if (publicBaseModels.Add(providerName) || references[baseTypeName].Count != before) - { - addedReference = true; - } - } - } - } - - private static void AddBasePreservedReferences( - IReadOnlyList providers, - HashSet nodes, - IReadOnlyDictionary> references, - HashSet reachableTypes) - { - var basePreservedRoots = new HashSet(StringComparer.Ordinal); - var addedRoot = true; - while (addedRoot) - { - addedRoot = false; - foreach (var provider in GetGeneratedProviders(providers)) - { - var providerName = GetProviderTypeName(provider.Type); - if (!nodes.Contains(providerName) || reachableTypes.Contains(providerName) || basePreservedRoots.Contains(providerName)) - { - continue; - } - - var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); - if (baseTypeName == null || !reachableTypes.Contains(baseTypeName)) - { - continue; - } - - if (basePreservedRoots.Add(providerName)) - { - addedRoot = true; - } - } - - if (addedRoot) - { - reachableTypes.UnionWith(GetReachableTypes(basePreservedRoots, references)); - } - } - } - - private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) - { - var generatedProviders = new List(); - foreach (var provider in providers) - { - AddGeneratedProvider(generatedProviders, provider); - } - - return generatedProviders; - } - - private static void AddGeneratedProvider(List generatedProviders, TypeProvider provider) - { - generatedProviders.Add(provider); - foreach (var nestedType in provider.NestedTypes) - { - AddGeneratedProvider(generatedProviders, nestedType); - } - - foreach (var serializationProvider in provider.SerializationProviders) - { - AddGeneratedProvider(generatedProviders, serializationProvider); - } - } - - private static void AddGeneratedBodyReferences(IReadOnlyList providers, ProviderReferenceGraph graph) - { - foreach (var (provider, isSerializationProvider) in GetBodyReferenceProviders(providers)) - { - if (IsModelFactoryProvider(provider) || - !IsGeneratedBodyReferenceCandidate(provider, isSerializationProvider)) - { - continue; - } - - var providerName = GetProviderTypeName(provider.Type); - if (!graph.Nodes.Contains(providerName)) - { - continue; - } - AddProviderBodyDependencyTypes( - graph.References[providerName], - GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), - graph.Nodes); - AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); - AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, graph.References[providerName]); - } - } - - private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes(TypeProvider provider, HashSet nodes) - { - var references = new List(); - foreach (var dependency in CollectStructuredBodyReferenceTypes(provider)) - { - if (!IsEnumProviderDependency(dependency, nodes)) - { - references.Add(dependency); - } - } - - return references; - } - - private static IReadOnlyList CollectStructuredBodyReferenceTypes(TypeProvider provider) - { - var references = new HashSet(); - var visited = new HashSet(ReferenceEqualityComparer.Instance); - - foreach (var field in provider.Fields) - { - CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); - } - - foreach (var property in provider.Properties) - { - CollectStructuredBodyReferenceTypes(property.Body, references, visited); - } - - foreach (var constructor in provider.Constructors) - { - CollectStructuredBodyReferenceTypes(constructor.BodyExpression, references, visited); - CollectStructuredBodyReferenceTypes(constructor.BodyStatements, references, visited); - } - - foreach (var method in provider.Methods) - { - if (method.IsMethodSuppressed()) - { - continue; - } - - CollectStructuredBodyReferenceTypes(method.BodyExpression, references, visited); - CollectStructuredBodyReferenceTypes(method.BodyStatements, references, visited); - } - - return [.. references]; - } - - private static void CollectStructuredBodyReferenceTypes(object? value, HashSet references, HashSet visited) - { - switch (value) - { - case null: - case string: - case FormattableString: - return; - } - - if (!value.GetType().IsValueType && !visited.Add(value)) - { - return; - } - - switch (value) - { - case CSharpType type: - references.Add(type); - return; - case Type type: - references.Add(type); - return; - case ParameterProvider parameter: - references.Add(parameter.Type); - CollectStructuredBodyReferenceTypes(parameter.DefaultValue, references, visited); - CollectStructuredBodyReferenceTypes(parameter.InitializationValue, references, visited); - return; - case MethodSignatureBase signature: - CollectStructuredBodyReferenceTypes(signature.ReturnType, references, visited); - CollectStructuredBodyReferenceTypes(signature.Parameters, references, visited); - return; - case KeyValuePair positionalArgument: - CollectStructuredBodyReferenceTypes(positionalArgument.Value, references, visited); - return; - case FieldProvider field: - references.Add(field.Type); - CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); - return; - } - - if (IsStructuredBodyReferenceObject(value)) - { - foreach (var property in value.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance)) - { - if (property.GetIndexParameters().Length > 0) - { - continue; - } - - CollectStructuredBodyReferenceTypes(property.GetValue(value), references, visited); - } - - return; - } - - if (value is not IEnumerable values) - { - return; - } - - foreach (var item in values) - { - CollectStructuredBodyReferenceTypes(item, references, visited); - } - } - - private static bool IsEnumProviderDependency(CSharpType dependency, HashSet nodes) - { - var providerName = GetProviderTypeName(dependency); - if (!nodes.Contains(providerName)) - { - return false; - } - - foreach (var provider in CodeModelGenerator.Instance.OutputLibrary.TypeProviders) - { - if (provider is EnumProvider && - string.Equals(GetProviderTypeName(provider.Type), providerName, StringComparison.Ordinal)) - { - return true; - } - } - - return false; - } - - private static bool IsStructuredBodyReferenceObject(object value) => - value is ValueExpression || - value is MethodBodyStatement || - value is PropertyBody; - - private static void AddProviderBodyDependencyTypes( - HashSet references, - IReadOnlyList dependencies, - HashSet nodes, - bool includeSimpleNameReferences = false) - { - foreach (var dependency in dependencies) - { - AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences); - } - } - - private static void AddProviderBodyDependencyType( - HashSet references, - CSharpType? dependency, - HashSet nodes, - bool includeSimpleNameReferences) - { - if (dependency == null) - { - return; - } - - AddTypeReference(references, dependency, nodes); - if (includeSimpleNameReferences) - { - AddMatchingName(references, dependency.Name, nodes); - } - if (nodes.Contains(GetProviderTypeName(dependency))) - { - AddMatchingName(references, $"{dependency.Name}Extensions", nodes); - } - else if (string.Equals(dependency.Name, "RequestContext", StringComparison.Ordinal)) - { - AddMatchingName(references, "RequestContextExtensions", nodes); - } - - foreach (var argument in dependency.Arguments) - { - AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences); - } - } - - private static IReadOnlyList<(TypeProvider Provider, bool IsSerializationProvider)> GetBodyReferenceProviders(IReadOnlyList providers) - { - var bodyReferenceProviders = new List<(TypeProvider Provider, bool IsSerializationProvider)>(); - foreach (var provider in providers) - { - bodyReferenceProviders.Add((provider, false)); - foreach (var serializationProvider in provider.SerializationProviders) - { - bodyReferenceProviders.Add((serializationProvider, true)); - } - } - - return bodyReferenceProviders; - } - - private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, bool isSerializationProvider) - { - if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) - { - return true; - } - - return provider.IsReferenceMapRoot || - isSerializationProvider || - provider.IncludeGeneratedBodyReferences || - provider.HelperDependencyTypes.Count > 0 || - provider.BodyDependencyTypes.Count > 0; - } - - private static HashSet GetRootNames( - IReadOnlyList providers, - HashSet nodes, - HashSet helperRoots, - bool includeModelFactory, - bool includeAdditionalRoots, - bool includeUnionVariantRoots, - bool publicClientRootsOnly) - { - var generator = CodeModelGenerator.Instance; - var roots = new HashSet(StringComparer.Ordinal); - var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); - - foreach (var provider in providers) - { - var name = GetProviderTypeName(provider.Type); - if (IsReferenceMapRootProvider(provider, publicClientRootsOnly) || - includeAdditionalRoots && IsAdditionalRootProvider(provider, generator.AdditionalRootTypes, nodes) || - includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || - includeModelFactory && helperRoots.Contains(name)) - { - roots.Add(name); - } - } - - AddLastContractModelFactorySignatureRoots(providers, roots, nodes); - - if (!includeUnionVariantRoots) - { - return roots; - } - - AddUnionVariantRoots(roots, providers, nodes); - - return roots; - } - - private static void AddLastContractModelFactorySignatureRoots(IReadOnlyList providers, HashSet roots, HashSet nodes) - { - foreach (var provider in providers) - { - if (!IsModelFactoryProvider(provider)) - { - continue; - } - - foreach (var method in provider.LastContractView?.Methods ?? []) - { - if (!method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public) || - IsImplementationOnlyModelFactoryMethod(method)) - { - continue; - } - - AddTypeReference(roots, method.Signature.ReturnType, nodes); - foreach (var parameter in method.Signature.Parameters) - { - AddTypeReference(roots, parameter.Type, nodes); - } - } - } - } - - private static void AddUnionVariantRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) - { - var unionVariantTypesToKeep = CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep; - foreach (var provider in GetGeneratedProviders(providers)) - { - if (provider is not ModelProvider || - !unionVariantTypesToKeep.Contains(provider.Type.Name) || - string.Equals(provider.Type.Namespace, "TypeSpec.Http", StringComparison.Ordinal)) - { - continue; - } - - AddMatchingName(roots, GetProviderTypeName(provider.Type), nodes); - } - } - - private static bool ShouldUseUnionVariantFallbackRoots() => - !HasApiBaselineDirectory() && - CodeModelGenerator.Instance.SourceInputModel.LastContract == null; - - private static bool IsImplementationOnlyModelFactoryMethod(MethodProvider method) - { - var returnType = method.Signature.ReturnType; - if (returnType == null) - { - return true; - } - - var returnTypeName = GetSimpleName(GetProviderTypeName(returnType)); - return returnTypeName.StartsWith("Paged", StringComparison.Ordinal) || - returnTypeName.EndsWith("Request", StringComparison.Ordinal); - } - - private static void RemoveMethodsFromModelFactory(HashSet namesToRemove) - { - if (namesToRemove.Count == 0) - { - return; - } - - var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value; - _preWriteModelFactory = modelFactory; - _preWriteModelFactoryMethods ??= [.. modelFactory.Methods]; - var methodsToKeep = new List(); - foreach (var method in modelFactory.Methods) - { - if (!namesToRemove.Contains(method.Signature.Name)) - { - methodsToKeep.Add(method); - } - } - - modelFactory.Update(methods: methodsToKeep); - } - - private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) - { - var generator = CodeModelGenerator.Instance; - var excludedNames = generator.NonRootTypes; - var declaredNodes = new HashSet(StringComparer.Ordinal); - foreach (var provider in GetGeneratedProviders(providers)) - { - if (IsModelFactoryProvider(provider)) - { - continue; - } - - if (publicOnly && !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) - { - continue; - } - - var name = GetProviderTypeName(provider.Type); - if (!nodes.Contains(name) || - excludedNames.Contains(name) || - excludedNames.Contains(GetSimpleName(name))) - { - continue; - } - - declaredNodes.Add(name); - } - - return declaredNodes; - } - - private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) - { - var providerName = GetProviderTypeName(type); - if (roots.Contains(providerName) && nodes.Contains(providerName)) - { - return true; - } - - if (!roots.Contains(type.Name)) - { - return false; - } - - var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); - return simpleNameLookup.TryGetValue(type.Name, out var matches) && - matches.Length == 1 && - string.Equals(matches[0], providerName, StringComparison.Ordinal); - } - - private static bool IsReferenceMapRootProvider(TypeProvider provider, bool publicOnly) => - provider.IsReferenceMapRoot && - (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); - - private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) - { - if (provider.DeclaringTypeProvider != null || !IsKept(provider.Type, roots, nodes)) - { - return false; - } - - return provider is not ModelProvider && provider is not EnumProvider; - } - - private static bool HasApiBaselineDirectory() - { - var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; - return !string.IsNullOrEmpty(projectDirectory) && - Directory.Exists(Path.GetFullPath(Path.Combine(projectDirectory, "..", "api"))); - } - - private static bool IsModelFactoryProvider(TypeProvider provider) - => provider is ModelFactoryProvider; - - private static HashSet GetHelperRootNames( - IReadOnlyList providers, - HashSet nodes, - HashSet reachableTypes, - IReadOnlyDictionary>? references = null) - { - var roots = new HashSet(StringComparer.Ordinal); - foreach (var provider in GetGeneratedProviders(providers)) - { - var providerName = GetProviderTypeName(provider.Type); - var isModelFactory = IsModelFactoryProvider(provider); - if (!reachableTypes.Contains(providerName) && !isModelFactory) - { - continue; - } - - AddHelperDependencies(roots, provider.HelperDependencyTypes, nodes, references == null ? null : references[providerName]); - - foreach (var property in provider.Properties) - { - AddInitializationHelperRoot(roots, property.Type, nodes); - AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); - } - - foreach (var field in provider.Fields) - { - AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); - } - - foreach (var constructor in provider.Constructors) - { - foreach (var parameter in constructor.Signature.Parameters) - { - AddParameterValidationHelperRoot(roots, parameter, nodes); - } - } - - foreach (var method in provider.Methods) - { - // Only factory methods for reachable models can instantiate collection helpers. - if (isModelFactory && - (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) - { - continue; - } - - foreach (var parameter in method.Signature.Parameters) - { - AddParameterValidationHelperRoot(roots, parameter, nodes); - if (isModelFactory) - { - AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); - } - } - } - } - - return roots; - } - - private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) - { - if (parameter.Validation != ParameterValidationType.None) - { - AddMatchingName(roots, "Argument", nodes); - } - } - - private static void AddHelperDependencies( - HashSet roots, - IReadOnlyList dependencies, - HashSet nodes, - HashSet? referencedNames) - { - foreach (var dependency in dependencies) - { - if (referencedNames == null) - { - AddTypeReference(roots, dependency, nodes); - continue; - } - - var matches = new HashSet(StringComparer.Ordinal); - AddTypeReference(matches, dependency, nodes); - foreach (var match in matches) - { - if (referencedNames.Contains(match)) - { - roots.Add(match); - } - } - } - } - - private static void RemoveUnusedRequestHeaderExtensionsRoot( - HashSet roots, - IReadOnlyDictionary> references, - IReadOnlyList providers) - { - var hasCustomReference = HasCustomRequestHeaderExtensionsReference(providers); - if (hasCustomReference) - { - return; - } - - var unusedRequestHeaderExtensions = new List(); - foreach (var root in roots) - { - if (IsRequestHeadersExtensionsRoot(root) && - !HasExternalReference(root, references)) - { - unusedRequestHeaderExtensions.Add(root); - } - } - - roots.ExceptWith(unusedRequestHeaderExtensions); - } - - private static bool HasExternalReference(string root, IReadOnlyDictionary> references) - { - foreach (var (source, sourceReferences) in references) - { - if (!string.Equals(source, root, StringComparison.Ordinal) && - sourceReferences.Contains(root)) - { - return true; - } - } - - return false; - } - - private static bool IsRequestHeadersExtensionsRoot(string root) => - root.EndsWith(".RequestHeaderExtensions", StringComparison.Ordinal) || - root.EndsWith(".RequestHeadersExtensions", StringComparison.Ordinal); - - private static bool HasCustomRequestHeaderExtensionsReference(IReadOnlyList providers) - { - foreach (var customCodeView in GetCustomCodeViews(providers)) - { - if (customCodeView is NamedTypeSymbolProvider) - { - if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || - HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || - HasRequestHeaderExtensionsDependency(customCodeView.SignatureDependencyTypes)) - { - return true; - } - - continue; - } - - if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || - HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || - HasRequestHeaderExtensionsMethodDependency(customCodeView.Methods) || - HasRequestHeaderExtensionsPropertyDependency(customCodeView.Properties) || - HasRequestHeaderExtensionsFieldDependency(customCodeView.Fields)) - { - return true; - } - } - - return false; - } - - private static bool HasRequestHeaderExtensionsDependency(IEnumerable dependencies) - { - foreach (var dependency in dependencies) - { - if (IsRequestHeaderExtensionsDependency(dependency)) - { - return true; - } - } - - return false; - } - - private static bool HasRequestHeaderExtensionsMethodDependency(IReadOnlyList methods) - { - foreach (var method in methods) - { - if (IsRequestHeaderExtensionsDependency(method.Signature.ReturnType)) - { - return true; - } - - foreach (var parameter in method.Signature.Parameters) - { - if (IsRequestHeaderExtensionsDependency(parameter.Type)) - { - return true; - } - } - } - - return false; - } - - private static bool HasRequestHeaderExtensionsPropertyDependency(IReadOnlyList properties) - { - foreach (var property in properties) - { - if (IsRequestHeaderExtensionsDependency(property.Type)) - { - return true; - } - } - - return false; - } - - private static bool HasRequestHeaderExtensionsFieldDependency(IReadOnlyList fields) - { - foreach (var field in fields) - { - if (IsRequestHeaderExtensionsDependency(field.Type)) - { - return true; - } - } - - return false; - } - - private static bool IsRequestHeaderExtensionsDependency(string name) - => string.Equals(name, "RequestHeaderExtensions", StringComparison.Ordinal) || - string.Equals(name, "SetDelimited", StringComparison.Ordinal); - - private static bool IsRequestHeaderExtensionsDependency(CSharpType? type) - { - if (type == null) - { - return false; - } - - if (IsRequestHeaderExtensionsDependency(type.Name)) - { - return true; - } - - foreach (var argument in type.Arguments) - { - if (IsRequestHeaderExtensionsDependency(argument)) - { - return true; - } - } - - return false; - } - - private static bool IsSerializationProvider(TypeProvider provider) - { - var relativePath = provider.RelativeFilePath.Replace('\\', '/'); - return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || - relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); - } - - private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) - { - if (type == null) - { - return; - } - - var initializationType = type.PropertyInitializationType; - if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) - { - AddMatchingName(roots, initializationType.Name, nodes); - } - - if (type is { IsList: true, IsReadOnlyMemory: false }) - { - AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); - } - - if (type.IsDictionary) - { - AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); - } - - foreach (var argument in type.Arguments) - { - AddInitializationHelperRoot(roots, argument, nodes); - } - } - - private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) - { - if (type == null) - { - return; - } - - if (type is { IsList: true, IsReadOnlyMemory: false }) - { - AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); - } - - if (type.IsDictionary) - { - AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); - } - - foreach (var argument in type.Arguments) - { - AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); - } - } - - private static void AddMatchingName(HashSet target, string name, HashSet nodes) - { - if (nodes.Contains(name)) - { - target.Add(name); - return; - } - - var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); - if (!simpleNameLookup.TryGetValue(name, out var matches)) - { - return; - } - - foreach (var match in matches) - { - target.Add(match); - } - } - - private static void AddMatchingNamesWithSimpleNameSuffix(HashSet target, string suffix, HashSet nodes) - { - foreach (var node in nodes) - { - if (GetSimpleName(node).EndsWith(suffix, StringComparison.Ordinal)) - { - target.Add(node); - } - } - } - - private static Dictionary BuildSimpleNameLookup(HashSet nodes) - { - var lookup = new Dictionary>(StringComparer.Ordinal); - foreach (var node in nodes) - { - var simpleName = StripGenericArity(GetSimpleName(node)); - if (!lookup.TryGetValue(simpleName, out var matchingNodes)) - { - matchingNodes = []; - lookup.Add(simpleName, matchingNodes); - } - - matchingNodes.Add(node); - } - - var result = new Dictionary(StringComparer.Ordinal); - foreach (var (simpleName, matchingNodes) in lookup) - { - result.Add(simpleName, [.. matchingNodes]); - } - - return result; - } - - private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) - { - return GetReachableTypes(roots, references, expandableNodes: null); - } - - private static HashSet GetReachableTypes( - HashSet roots, - IReadOnlyDictionary> references, - HashSet? expandableNodes) - { - var reachable = new HashSet(StringComparer.Ordinal); - var queue = new Queue(roots); - while (queue.Count > 0) - { - var current = queue.Dequeue(); - if (!reachable.Add(current)) - { - continue; - } - - if (expandableNodes != null && !expandableNodes.Contains(current)) - { - continue; - } - - if (!references.TryGetValue(current, out var children)) - { - continue; - } - - foreach (var child in children) - { - queue.Enqueue(child); - } - } - - return reachable; - } - - private static bool HasPublicApiPredecessor( - string name, - IReadOnlyDictionary> references, - HashSet publicizeReachable, - HashSet generatedImplementationInternalDeclarations) - { - foreach (var (owner, children) in references) - { - if (!publicizeReachable.Contains(owner) || - string.Equals(owner, name, StringComparison.Ordinal) || - generatedImplementationInternalDeclarations.Contains(owner) || - !children.Contains(name)) - { - continue; - } - - return true; - } - - return false; - } - - private static void AddSignatureReferences( - HashSet references, - MethodSignatureBase signature, - HashSet nodes, - IReadOnlyDictionary? serializationProviderNamesByType, - bool includeAttributes = true, - bool includeAttributeArguments = true) - { - AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); - if (includeAttributes) - { - AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); - } - - foreach (var parameter in signature.Parameters) - { - AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); - if (includeAttributes) - { - AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); - } - } - - if (signature is MethodSignature methodSignature) - { - AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); - if (methodSignature.GenericArguments != null) - { - foreach (var genericArgument in methodSignature.GenericArguments) - { - AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); - } - } - - if (methodSignature.GenericParameterConstraints != null) - { - foreach (var constraint in methodSignature.GenericParameterConstraints) - { - AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); - } - } - } - - if (signature is ConstructorSignature constructorSignature) - { - AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); - } - } - - private static void AddAttributes( - HashSet references, - IReadOnlyList attributes, - HashSet nodes, - IReadOnlyDictionary? serializationProviderNamesByType, - bool includeArguments) - { - foreach (var attribute in attributes) - { - AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); - if (!includeArguments) - { - continue; - } - - foreach (var argument in attribute.Arguments) - { - AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); - } - - foreach (var (_, argument) in attribute.PositionalArguments) - { - AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); - } - } - } - - private static bool IsAttributeNamed(AttributeStatement attribute, string name) - => string.Equals(attribute.Type.Name, name, StringComparison.Ordinal) || - string.Equals(attribute.Type.Name, $"{name}Attribute", StringComparison.Ordinal); - - private static void AddAttributeArgumentReference( - HashSet references, - ValueExpression argument, - HashSet nodes, - IReadOnlyDictionary? serializationProviderNamesByType) - { - if (argument is TypeOfExpression typeOf) - { - AddTypeReference(references, typeOf.Type, nodes, serializationProviderNamesByType); - AddMatchingName(references, typeOf.Type.Name, nodes); - } - } - - private static void AddTypeReference( - HashSet references, - CSharpType? type, - HashSet nodes, - IReadOnlyDictionary? serializationProviderNamesByType = null) - { - if (type == null) - { - return; - } - - if (type.IsArray) - { - AddTypeReference(references, type.ElementType, nodes, serializationProviderNamesByType); - return; - } - - var providerTypeName = GetProviderTypeName(type); - if (nodes.Contains(providerTypeName)) - { - references.Add(providerTypeName); - if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) - { - foreach (var serializationProviderName in serializationProviderNames) - { - references.Add(serializationProviderName); - } - } - } - - AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); - AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); - foreach (var argument in type.Arguments) - { - AddTypeReference(references, argument, nodes, serializationProviderNamesByType); - } - } - - private static string GetSimpleName(string fullyQualifiedName) - { - var lastDot = fullyQualifiedName.LastIndexOf('.'); - return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); - } - - private static string? GetNamespaceName(string fullyQualifiedName) - { - var lastDot = fullyQualifiedName.LastIndexOf('.'); - return lastDot < 0 ? null : fullyQualifiedName.Substring(0, lastDot); - } - - private static string GetProviderTypeName(CSharpType type) - { - var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) - ? $"{type.Name}`{type.Arguments.Count}" - : type.Name; - return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; - } - - private static string StripGenericArity(string name) - { - var tick = name.IndexOf('`'); - return tick < 0 ? name : name.Substring(0, tick); - } - - private sealed record ProviderReferenceGraph( - HashSet Nodes, - Dictionary> References); - } -} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs index 37bac4c1202..33d0fa5eb32 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs @@ -65,7 +65,7 @@ internal static IReadOnlyList GetUnionTypesDescriptions(IReadOn } else { - description = new XmlDocStatement("description", [$"{item:C}"]); + description = new XmlDocStatement("description", [$"{item}"]); } values.Add(description); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs new file mode 100644 index 00000000000..79fdd39ec80 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -0,0 +1,254 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Reflection; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static void AddGeneratedBodyReferences(IReadOnlyList providers, ProviderReferenceGraph graph) + { + foreach (var (provider, isSerializationProvider) in GetBodyReferenceProviders(providers)) + { + if (IsModelFactoryProvider(provider) || + !IsGeneratedBodyReferenceCandidate(provider, isSerializationProvider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + AddProviderBodyDependencyTypes( + graph.References[providerName], + GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), + graph.Nodes); + AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); + AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, graph.References[providerName]); + } + } + + private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes(TypeProvider provider, HashSet nodes) + { + var references = new List(); + foreach (var dependency in CollectStructuredBodyReferenceTypes(provider)) + { + if (!IsEnumProviderDependency(dependency, nodes)) + { + references.Add(dependency); + } + } + + return references; + } + + private static IReadOnlyList CollectStructuredBodyReferenceTypes(TypeProvider provider) + { + var references = new HashSet(); + var visited = new HashSet(ReferenceEqualityComparer.Instance); + + foreach (var field in provider.Fields) + { + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + } + + foreach (var property in provider.Properties) + { + CollectStructuredBodyReferenceTypes(property.Body, references, visited); + } + + foreach (var constructor in provider.Constructors) + { + CollectStructuredBodyReferenceTypes(constructor.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(constructor.BodyStatements, references, visited); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + CollectStructuredBodyReferenceTypes(method.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(method.BodyStatements, references, visited); + } + + return [.. references]; + } + + private static void CollectStructuredBodyReferenceTypes(object? value, HashSet references, HashSet visited) + { + switch (value) + { + case null: + case string: + case FormattableString: + return; + } + + if (!value.GetType().IsValueType && !visited.Add(value)) + { + return; + } + + switch (value) + { + case CSharpType type: + references.Add(type); + return; + case Type type: + references.Add(type); + return; + case ParameterProvider parameter: + references.Add(parameter.Type); + CollectStructuredBodyReferenceTypes(parameter.DefaultValue, references, visited); + CollectStructuredBodyReferenceTypes(parameter.InitializationValue, references, visited); + return; + case MethodSignatureBase signature: + CollectStructuredBodyReferenceTypes(signature.ReturnType, references, visited); + CollectStructuredBodyReferenceTypes(signature.Parameters, references, visited); + return; + case KeyValuePair positionalArgument: + CollectStructuredBodyReferenceTypes(positionalArgument.Value, references, visited); + return; + case FieldProvider field: + references.Add(field.Type); + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + return; + } + + if (IsStructuredBodyReferenceObject(value)) + { + foreach (var property in value.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + if (property.GetIndexParameters().Length > 0) + { + continue; + } + + CollectStructuredBodyReferenceTypes(property.GetValue(value), references, visited); + } + + return; + } + + if (value is not IEnumerable values) + { + return; + } + + foreach (var item in values) + { + CollectStructuredBodyReferenceTypes(item, references, visited); + } + } + + private static bool IsEnumProviderDependency(CSharpType dependency, HashSet nodes) + { + var providerName = GetProviderTypeName(dependency); + if (!nodes.Contains(providerName)) + { + return false; + } + + foreach (var provider in CodeModelGenerator.Instance.OutputLibrary.TypeProviders) + { + if (provider is EnumProvider && + string.Equals(GetProviderTypeName(provider.Type), providerName, StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + + private static bool IsStructuredBodyReferenceObject(object value) => + value is ValueExpression || + value is MethodBodyStatement || + value is PropertyBody; + + private static void AddProviderBodyDependencyTypes( + HashSet references, + IReadOnlyList dependencies, + HashSet nodes, + bool includeSimpleNameReferences = false) + { + foreach (var dependency in dependencies) + { + AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences); + } + } + + private static void AddProviderBodyDependencyType( + HashSet references, + CSharpType? dependency, + HashSet nodes, + bool includeSimpleNameReferences) + { + if (dependency == null) + { + return; + } + + AddTypeReference(references, dependency, nodes); + if (includeSimpleNameReferences) + { + AddMatchingName(references, dependency.Name, nodes); + } + if (nodes.Contains(GetProviderTypeName(dependency))) + { + AddMatchingName(references, $"{dependency.Name}Extensions", nodes); + } + else if (string.Equals(dependency.Name, "RequestContext", StringComparison.Ordinal)) + { + AddMatchingName(references, "RequestContextExtensions", nodes); + } + + foreach (var argument in dependency.Arguments) + { + AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences); + } + } + + private static IReadOnlyList<(TypeProvider Provider, bool IsSerializationProvider)> GetBodyReferenceProviders(IReadOnlyList providers) + { + var bodyReferenceProviders = new List<(TypeProvider Provider, bool IsSerializationProvider)>(); + foreach (var provider in providers) + { + bodyReferenceProviders.Add((provider, false)); + foreach (var serializationProvider in provider.SerializationProviders) + { + bodyReferenceProviders.Add((serializationProvider, true)); + } + } + + return bodyReferenceProviders; + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, bool isSerializationProvider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + return provider.IsReferenceMapRoot || + isSerializationProvider || + provider.IncludeGeneratedBodyReferences || + provider.HelperDependencyTypes.Count > 0 || + provider.BodyDependencyTypes.Count > 0; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs new file mode 100644 index 00000000000..ff39f833d3d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static HashSet GetCustomInternalBoundaryNodes( + ProviderReferenceGraph publicGraph, + HashSet customInternalDeclarations) + { + var boundaryNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in publicGraph.Nodes) + { + if (!publicGraph.References.TryGetValue(node, out var references)) + { + continue; + } + + if (references.Overlaps(customInternalDeclarations)) + { + boundaryNodes.Add(node); + } + } + + return boundaryNodes; + } + + private static HashSet GetPublicizeDeclaredNodes( + IReadOnlyList generatedProviders, + HashSet nodes, + HashSet internalizeDeclaredNodes) + { + var publicizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, nodes, publicOnly: false); + return publicizeDeclaredNodes; + } + + private static HashSet GetPublicApiTraversalNodes( + HashSet internalizeDeclaredNodes, + HashSet publicizeDeclaredNodes, + HashSet generatedInternalDeclarations, + HashSet generatedImplementationInternalDeclarations) + { + var traversalNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (generatedInternalDeclarations.Contains(node) || + generatedImplementationInternalDeclarations.Contains(node)) + { + continue; + } + + traversalNodes.Add(node); + } + + foreach (var node in publicizeDeclaredNodes) + { + if (!generatedImplementationInternalDeclarations.Contains(node)) + { + traversalNodes.Add(node); + } + } + + return traversalNodes; + } + + private static HashSet GetInternalizeCandidates( + HashSet internalizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet publicizeRoots, + IReadOnlyDictionary> references) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (!publicizeReachable.Contains(node) || + customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) && !publicizeRoots.Contains(node)) + { + candidates.Add(node); + } + } + + // If a public non-root type exposes something that must remain internal, make the + // exposing type internal too. That avoids generating public APIs with internal types. + var addedCandidate = true; + while (addedCandidate) + { + addedCandidate = false; + foreach (var node in internalizeDeclaredNodes) + { + if (candidates.Contains(node) || + publicizeRoots.Contains(node) || + !references.TryGetValue(node, out var nodeReferences) || + !nodeReferences.Overlaps(candidates)) + { + continue; + } + + candidates.Add(node); + addedCandidate = true; + } + } + + return candidates; + } + + private static HashSet GetPublicizeCandidates( + HashSet publicizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet internalizeHelperRoots, + HashSet publicizeRootExclusions, + HashSet generatedInternalDeclarations, + HashSet publicizeRoots, + Dictionary> publicApiReferences, + Dictionary> internalizeReferences, + HashSet generatedImplementationInternalDeclarations) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in publicizeDeclaredNodes) + { + if (customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) || + internalizeHelperRoots.Contains(node) || + publicizeRootExclusions.Contains(node) || + !publicizeReachable.Contains(node)) + { + continue; + } + + if (generatedInternalDeclarations.Contains(node) && + !publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, publicApiReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + { + continue; + } + + if (!publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, internalizeReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + { + continue; + } + + candidates.Add(node); + } + + return candidates; + } + + private static HashSet GetRemovalCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + HashSet customRemovalRoots, + HashSet generatedDiscriminatorBaseNames) + { + var removeRoots = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: true, + includeModelFactorySignatureRoots: true, + publicClientRootsOnly: false); + + removeRoots.UnionWith(customRemovalRoots); + AddMatchingNamesWithSimpleNameSuffix(removeRoots, "ReferenceType", graph.Nodes); + AddCustomCodeExtensionRoots(removeRoots, generatedProviders, graph.Nodes); + AddCustomizationBackedExtensionRoots(removeRoots, graph.Nodes); + AddCustomRequestHeaderExtensionsRoot(removeRoots, generatedProviders, graph.Nodes); + RemoveUnusedRequestHeaderExtensionsRoot(removeRoots, graph.References, providers); + + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddDerivedModelReferences(providers, graph.Nodes, graph.References, removeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachableWithoutHelpers); + + var removeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, removeReachableWithoutHelpers, graph.References); + removeRoots.UnionWith(removeHelperRoots); + + var removeReachable = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachable); + + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: false); + removeDeclaredNodes.ExceptWith(removeReachable); + return removeDeclaredNodes; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs new file mode 100644 index 00000000000..0c5cc2cfe61 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static HashSet GetCustomCodeGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: false); + } + + return roots; + } + + private static HashSet GetCustomCodePublicGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: true); + } + + return roots; + } + + private static IEnumerable GetCustomCodeViews(IReadOnlyList providers) + { + var visited = new HashSet(StringComparer.Ordinal); + var modelFactoryCustomCodeView = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value.CustomCodeView; + if (modelFactoryCustomCodeView != null && visited.Add(GetCustomCodeViewIdentity(modelFactoryCustomCodeView))) + { + yield return modelFactoryCustomCodeView; + } + + foreach (var provider in providers) + { + var customCodeView = provider.CustomCodeView; + if (customCodeView == null || !visited.Add(GetCustomCodeViewIdentity(customCodeView))) + { + continue; + } + + yield return customCodeView; + } + + foreach (var customTypeProvider in CodeModelGenerator.Instance.SourceInputModel.GetCustomizationTypeProviders()) + { + if (visited.Add(GetCustomCodeViewIdentity(customTypeProvider))) + { + yield return customTypeProvider; + } + } + } + + private static string GetCustomCodeViewIdentity(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataName + : GetProviderTypeName(customCodeView.Type); + + private static void AddCustomRequestHeaderExtensionsRoot(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + // TODO: Resolve body-level SetDelimited extension calls to PipelineRequestHeadersExtensions so this can be a normal type edge. + if (!HasCustomRequestHeaderExtensionsReference(providers)) + { + return; + } + + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeaderExtensions", nodes); + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeadersExtensions", nodes); + } + + private static void AddCustomCodeExtensionRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", nodes); + } + } + + private static string GetCustomCodeViewSimpleName(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataSimpleName + : customCodeView.Type.Name; + + private static void AddCustomCodeViewGeneratedTypeRoot(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames) + { + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + return; + } + + AddTypeReference(roots, customCodeView.Type, generatedTypeNames); + } + + private static void AddCustomizationBackedExtensionRoots(HashSet roots, HashSet nodes) + { + foreach (var node in nodes) + { + var simpleName = GetSimpleName(node); + if (!simpleName.EndsWith("Extensions", StringComparison.Ordinal)) + { + continue; + } + + var namespaceName = GetNamespaceName(node); + if (namespaceName == null) + { + continue; + } + + var customTypeName = simpleName.Substring(0, simpleName.Length - "Extensions".Length); + if (CodeModelGenerator.Instance.SourceInputModel.FindForTypeInCustomization(namespaceName, customTypeName) != null) + { + roots.Add(node); + } + } + } + + private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames, bool publicOnly) + { + AddTypeReference(roots, customCodeView.BaseType, generatedTypeNames); + AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + if (!publicOnly) + { + AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); + } + + foreach (var implementedType in customCodeView.Implements) + { + AddTypeReference(roots, implementedType, generatedTypeNames); + } + + foreach (var constructor in customCodeView.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, constructor.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var method in customCodeView.Methods) + { + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, method.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var property in customCodeView.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(roots, property.Type, generatedTypeNames); + AddTypeReference(roots, property.ExplicitInterface, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, property.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + + foreach (var field in customCodeView.Fields) + { + if (publicOnly && !IsPublic(field.Modifiers)) + { + continue; + } + + AddTypeReference(roots, field.Type, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, field.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + } + + private static HashSet GetApiBaselineGeneratedTypeRoots(HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + if (string.IsNullOrEmpty(projectDirectory)) + { + return roots; + } + + var apiDirectory = Path.GetFullPath(Path.Combine(projectDirectory, "..", "api")); + if (!Directory.Exists(apiDirectory)) + { + return roots; + } + + var apiText = string.Join("\n", Directory.GetFiles(apiDirectory, "*.cs", SearchOption.AllDirectories).Select(File.ReadAllText)); + var apiDeclaredTypeNames = GetApiDeclaredTypeNames(apiText); + foreach (var fullName in generatedTypeNames) + { + var simpleName = StripGenericArity(GetSimpleName(fullName)); + var normalizedFullName = StripGenericArity(fullName); + if (!ContainsApiTypeReference(apiText, apiDeclaredTypeNames, normalizedFullName, simpleName)) + { + continue; + } + + roots.Add(fullName); + } + + return roots; + } + + private static HashSet GetApiDeclaredTypeNames(string apiText) + { + var declaredTypeNames = new HashSet(StringComparer.Ordinal); + string? currentNamespace = null; + foreach (var line in apiText.Split('\n')) + { + var namespaceMatch = Regex.Match(line, @"^namespace\s+([\w.]+)\s*\{?\s*$"); + if (namespaceMatch.Success) + { + currentNamespace = namespaceMatch.Groups[1].Value; + continue; + } + + if (currentNamespace == null) + { + continue; + } + + var declarationMatch = Regex.Match(line, @"^ \S.*?\b(class|struct|interface|enum)\s+([A-Za-z_][A-Za-z0-9_]*)(?!\s*<)(?!\w)"); + if (declarationMatch.Success) + { + declaredTypeNames.Add($"{currentNamespace}.{declarationMatch.Groups[2].Value}"); + } + } + + return declaredTypeNames; + } + + private static bool ContainsApiTypeReference(string apiText, HashSet apiDeclaredTypeNames, string fullName, string simpleName) + { + var fullNamePattern = $@"(? GetCustomCodeInternalGeneratedTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)) + { + continue; + } + + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(declarations, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + } + else + { + AddTypeReference(declarations, customCodeView.Type, generatedTypeNames); + } + } + + return declarations; + } + + private static HashSet GetGeneratedPersistableModelProxyTypeNames(IReadOnlyList providers, HashSet generatedTypeNames) + { + var proxyTypes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.Attributes.Any(static attribute => IsAttributeNamed(attribute, "PersistableModelProxy"))) + { + AddTypeReference(proxyTypes, provider.Type, generatedTypeNames); + } + } + + return proxyTypes; + } + + private static HashSet GetGeneratedInternalTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Internal); + + private static HashSet GetGeneratedPublicTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Public); + + private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessibility( + IReadOnlyList providers, + HashSet generatedTypeNames, + TypeSignatureModifiers accessibility) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.LastContractView?.DeclarationModifiers.HasFlag(accessibility) != true) + { + continue; + } + + AddTypeReference(declarations, provider.Type, generatedTypeNames); + } + + return declarations; + } + + private static HashSet GetGeneratedImplementationInternalTypeDeclarations(HashSet generatedInternalDeclarations) + { + var implementationDeclarations = new HashSet(StringComparer.Ordinal); + foreach (var name in generatedInternalDeclarations) + { + if (GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)) + { + implementationDeclarations.Add(name); + } + } + + return implementationDeclarations; + } + + private static HashSet GetSimpleNames(HashSet names) + { + var simpleNames = new HashSet(StringComparer.Ordinal); + foreach (var name in names) + { + simpleNames.Add(GetSimpleName(name)); + } + + return simpleNames; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs new file mode 100644 index 00000000000..cf7569f9c9f --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceGraph BuildGraph(IReadOnlyList generatedProviders, bool publicOnly = false) + { + // Each generated provider becomes a node, and provider metadata supplies the edges: + // inheritance, signatures, properties, fields, nested/serialization providers, attributes, + // and selected implementation dependencies. This avoids parsing generated C# just to + // rediscover generated-to-generated references. + var serializationProviderNamesByType = GetSerializationProviderNamesByType(generatedProviders); + IReadOnlyDictionary? serializationReferenceNamesByType = publicOnly ? null : serializationProviderNamesByType; + var nodes = new HashSet(StringComparer.Ordinal); + var references = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (nodes.Add(providerName)) + { + references.Add(providerName, new HashSet(StringComparer.Ordinal)); + } + } + + foreach (var provider in generatedProviders) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.BaseType, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationReferenceNamesByType); + + if (!publicOnly && IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) + { + continue; + } + + // Model factory signatures mention many models. The existing Roslyn post-processor + // removes factory methods for unreachable models, so model factory should only + // contribute helper dependencies, not model reachability edges. + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes, serializationReferenceNamesByType); + } + + if (!publicOnly) + { + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes, serializationReferenceNamesByType); + } + } + + if (!publicOnly) + { + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes, serializationReferenceNamesByType); + } + } + + foreach (var property in provider.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(references[current], property.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], property.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var field in provider.Fields) + { + if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) + { + continue; + } + + AddTypeReference(references[current], field.Type, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], field.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var constructor in provider.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], constructor.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], method.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + if (!publicOnly) + { + AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationReferenceNamesByType); + } + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static Dictionary GetSerializationProviderNamesByType(IReadOnlyList generatedProviders) + { + var namesByType = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + if (provider.SerializationProviders.Count == 0) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!namesByType.TryGetValue(providerName, out var serializationProviderNames)) + { + serializationProviderNames = new HashSet(StringComparer.Ordinal); + namesByType.Add(providerName, serializationProviderNames); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + serializationProviderNames.Add(GetProviderTypeName(serializationProvider.Type)); + } + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (providerName, serializationProviderNames) in namesByType) + { + result.Add(providerName, [.. serializationProviderNames]); + } + + return result; + } + + private static CSharpType? GetCollectionDefinitionType(MethodProvider method) + { + var property = method.GetType().GetProperty("CollectionDefinition"); + return property?.GetValue(method) is TypeProvider collectionDefinition + ? collectionDefinition.Type + : null; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs new file mode 100644 index 00000000000..87f9e32f473 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -0,0 +1,484 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) + { + var providerName = GetProviderTypeName(type); + if (roots.Contains(providerName) && nodes.Contains(providerName)) + { + return true; + } + + if (!roots.Contains(type.Name)) + { + return false; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + return simpleNameLookup.TryGetValue(type.Name, out var matches) && + matches.Length == 1 && + string.Equals(matches[0], providerName, StringComparison.Ordinal); + } + + private static bool IsReferenceMapRootProvider(TypeProvider provider, bool publicOnly) => + provider.IsReferenceMapRoot && + (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + + private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) + { + if (provider.DeclaringTypeProvider != null || !IsKept(provider.Type, roots, nodes)) + { + return false; + } + + return provider is not ModelProvider && provider is not EnumProvider; + } + + private static bool HasApiBaselineDirectory() + { + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + return !string.IsNullOrEmpty(projectDirectory) && + Directory.Exists(Path.GetFullPath(Path.Combine(projectDirectory, "..", "api"))); + } + + private static bool IsModelFactoryProvider(TypeProvider provider) + => provider is ModelFactoryProvider; + + private static HashSet GetHelperRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet reachableTypes, + IReadOnlyDictionary>? references = null) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + AddHelperDependencies(roots, provider.HelperDependencyTypes, nodes, references == null ? null : references[providerName]); + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + // Only factory methods for reachable models can instantiate collection helpers. + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddHelperDependencies( + HashSet roots, + IReadOnlyList dependencies, + HashSet nodes, + HashSet? referencedNames) + { + foreach (var dependency in dependencies) + { + if (referencedNames == null) + { + AddTypeReference(roots, dependency, nodes); + continue; + } + + var matches = new HashSet(StringComparer.Ordinal); + AddTypeReference(matches, dependency, nodes); + foreach (var match in matches) + { + if (referencedNames.Contains(match)) + { + roots.Add(match); + } + } + } + } + + private static void RemoveUnusedRequestHeaderExtensionsRoot( + HashSet roots, + IReadOnlyDictionary> references, + IReadOnlyList providers) + { + var hasCustomReference = HasCustomRequestHeaderExtensionsReference(providers); + if (hasCustomReference) + { + return; + } + + var unusedRequestHeaderExtensions = new List(); + foreach (var root in roots) + { + if (IsRequestHeadersExtensionsRoot(root) && + !HasExternalReference(root, references)) + { + unusedRequestHeaderExtensions.Add(root); + } + } + + roots.ExceptWith(unusedRequestHeaderExtensions); + } + + private static bool HasExternalReference(string root, IReadOnlyDictionary> references) + { + foreach (var (source, sourceReferences) in references) + { + if (!string.Equals(source, root, StringComparison.Ordinal) && + sourceReferences.Contains(root)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeadersExtensionsRoot(string root) => + root.EndsWith(".RequestHeaderExtensions", StringComparison.Ordinal) || + root.EndsWith(".RequestHeadersExtensions", StringComparison.Ordinal); + + private static bool HasCustomRequestHeaderExtensionsReference(IReadOnlyList providers) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (customCodeView is NamedTypeSymbolProvider) + { + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.SignatureDependencyTypes)) + { + return true; + } + + continue; + } + + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsMethodDependency(customCodeView.Methods) || + HasRequestHeaderExtensionsPropertyDependency(customCodeView.Properties) || + HasRequestHeaderExtensionsFieldDependency(customCodeView.Fields)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsDependency(IEnumerable dependencies) + { + foreach (var dependency in dependencies) + { + if (IsRequestHeaderExtensionsDependency(dependency)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsMethodDependency(IReadOnlyList methods) + { + foreach (var method in methods) + { + if (IsRequestHeaderExtensionsDependency(method.Signature.ReturnType)) + { + return true; + } + + foreach (var parameter in method.Signature.Parameters) + { + if (IsRequestHeaderExtensionsDependency(parameter.Type)) + { + return true; + } + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsPropertyDependency(IReadOnlyList properties) + { + foreach (var property in properties) + { + if (IsRequestHeaderExtensionsDependency(property.Type)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsFieldDependency(IReadOnlyList fields) + { + foreach (var field in fields) + { + if (IsRequestHeaderExtensionsDependency(field.Type)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeaderExtensionsDependency(string name) + => string.Equals(name, "RequestHeaderExtensions", StringComparison.Ordinal) || + string.Equals(name, "SetDelimited", StringComparison.Ordinal); + + private static bool IsRequestHeaderExtensionsDependency(CSharpType? type) + { + if (type == null) + { + return false; + } + + if (IsRequestHeaderExtensionsDependency(type.Name)) + { + return true; + } + + foreach (var argument in type.Arguments) + { + if (IsRequestHeaderExtensionsDependency(argument)) + { + return true; + } + } + + return false; + } + + private static bool IsSerializationProvider(TypeProvider provider) + { + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || + relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + if (!simpleNameLookup.TryGetValue(name, out var matches)) + { + return; + } + + foreach (var match in matches) + { + target.Add(match); + } + } + + private static void AddMatchingNamesWithSimpleNameSuffix(HashSet target, string suffix, HashSet nodes) + { + foreach (var node in nodes) + { + if (GetSimpleName(node).EndsWith(suffix, StringComparison.Ordinal)) + { + target.Add(node); + } + } + } + + private static Dictionary BuildSimpleNameLookup(HashSet nodes) + { + var lookup = new Dictionary>(StringComparer.Ordinal); + foreach (var node in nodes) + { + var simpleName = StripGenericArity(GetSimpleName(node)); + if (!lookup.TryGetValue(simpleName, out var matchingNodes)) + { + matchingNodes = []; + lookup.Add(simpleName, matchingNodes); + } + + matchingNodes.Add(node); + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (simpleName, matchingNodes) in lookup) + { + result.Add(simpleName, [.. matchingNodes]); + } + + return result; + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + return GetReachableTypes(roots, references, expandableNodes: null); + } + + private static HashSet GetReachableTypes( + HashSet roots, + IReadOnlyDictionary> references, + HashSet? expandableNodes) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (expandableNodes != null && !expandableNodes.Contains(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static bool HasPublicApiPredecessor( + string name, + IReadOnlyDictionary> references, + HashSet publicizeReachable, + HashSet generatedImplementationInternalDeclarations) + { + foreach (var (owner, children) in references) + { + if (!publicizeReachable.Contains(owner) || + string.Equals(owner, name, StringComparison.Ordinal) || + generatedImplementationInternalDeclarations.Contains(owner) || + !children.Contains(name)) + { + continue; + } + + return true; + } + + return false; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs new file mode 100644 index 00000000000..77327749735 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static Dictionary> CloneReferences(IReadOnlyDictionary> references) + { + var clone = new Dictionary>(StringComparer.Ordinal); + foreach (var (name, referencedNames) in references) + { + clone.Add(name, new HashSet(referencedNames, StringComparer.Ordinal)); + } + + return clone; + } + + private static void AddDerivedModelReferences( + IReadOnlyList providers, + HashSet nodes, + Dictionary> references, + HashSet publicBaseModels, + HashSet generatedDiscriminatorBaseNames) + { + var modelProviders = new List(); + var discriminatorProviders = new List(); + var discriminatorBaseNames = new HashSet(StringComparer.Ordinal); + foreach (var provider in providers) + { + if (provider is not ModelProvider modelProvider || + !modelProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + modelProviders.Add(modelProvider); + + if (modelProvider.DiscriminatorProperty != null) + { + discriminatorBaseNames.Add(GetProviderTypeName(modelProvider.Type)); + } + + if (!modelProvider.IsUnknownDiscriminatorModel && + (modelProvider.DiscriminatorProperty != null || modelProvider.DiscriminatorValue != null)) + { + discriminatorProviders.Add(modelProvider); + } + } + + discriminatorBaseNames.UnionWith(generatedDiscriminatorBaseNames); + var addedReference = true; + while (addedReference) + { + addedReference = false; + foreach (var provider in discriminatorProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + if (!publicBaseModels.Contains(providerName)) + { + continue; + } + + foreach (var derivedModel in provider.DerivedModels) + { + if (derivedModel.IsUnknownDiscriminatorModel || + !derivedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var before = references[providerName].Count; + AddTypeReference(references[providerName], derivedModel.Type, nodes); + var derivedName = GetProviderTypeName(derivedModel.Type); + if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) + { + addedReference = true; + } + } + } + + foreach (var provider in modelProviders) + { + if (provider.IsUnknownDiscriminatorModel || + !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || + !discriminatorBaseNames.Contains(baseTypeName) || + !nodes.Contains(baseTypeName) || + !publicBaseModels.Contains(baseTypeName)) + { + continue; + } + + var before = references[baseTypeName].Count; + references[baseTypeName].Add(providerName); + if (publicBaseModels.Add(providerName) || references[baseTypeName].Count != before) + { + addedReference = true; + } + } + } + } + + private static void AddBasePreservedReferences( + IReadOnlyList providers, + HashSet nodes, + IReadOnlyDictionary> references, + HashSet reachableTypes) + { + var basePreservedRoots = new HashSet(StringComparer.Ordinal); + var addedRoot = true; + while (addedRoot) + { + addedRoot = false; + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName) || reachableTypes.Contains(providerName) || basePreservedRoots.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || !reachableTypes.Contains(baseTypeName)) + { + continue; + } + + if (basePreservedRoots.Add(providerName)) + { + addedRoot = true; + } + } + + if (addedRoot) + { + reachableTypes.UnionWith(GetReachableTypes(basePreservedRoots, references)); + } + } + } + + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + AddGeneratedProvider(generatedProviders, provider); + } + + return generatedProviders; + } + + private static void AddGeneratedProvider(List generatedProviders, TypeProvider provider) + { + generatedProviders.Add(provider); + foreach (var nestedType in provider.NestedTypes) + { + AddGeneratedProvider(generatedProviders, nestedType); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddGeneratedProvider(generatedProviders, serializationProvider); + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs new file mode 100644 index 00000000000..88d265e54c6 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static HashSet GetRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet helperRoots, + bool includeModelFactory, + bool includeAdditionalRoots, + bool includeUnionVariantRoots, + bool includeModelFactorySignatureRoots, + bool publicClientRootsOnly) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (IsReferenceMapRootProvider(provider, publicClientRootsOnly) || + includeAdditionalRoots && IsAdditionalRootProvider(provider, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + if (includeModelFactorySignatureRoots) + { + AddLastContractModelFactorySignatureRoots(providers, roots, nodes); + } + + if (!includeUnionVariantRoots) + { + return roots; + } + + AddUnionVariantRoots(roots, providers, nodes); + + return roots; + } + + private static void AddLastContractModelFactorySignatureRoots(IReadOnlyList providers, HashSet roots, HashSet nodes) + { + foreach (var provider in providers) + { + if (!IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var method in provider.LastContractView?.Methods ?? []) + { + if (!method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public) || + IsImplementationOnlyModelFactoryMethod(method)) + { + continue; + } + + AddTypeReference(roots, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddTypeReference(roots, parameter.Type, nodes); + } + } + } + } + + private static void AddUnionVariantRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + var unionVariantTypesToKeep = CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep; + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider is not ModelProvider || + !unionVariantTypesToKeep.Contains(provider.Type.Name) || + string.Equals(provider.Type.Namespace, "TypeSpec.Http", StringComparison.Ordinal)) + { + continue; + } + + AddMatchingName(roots, GetProviderTypeName(provider.Type), nodes); + } + } + + private static bool ShouldUseUnionVariantFallbackRoots() => + !HasApiBaselineDirectory() && + CodeModelGenerator.Instance.SourceInputModel.LastContract == null; + + private static bool IsImplementationOnlyModelFactoryMethod(MethodProvider method) + { + var returnType = method.Signature.ReturnType; + if (returnType == null) + { + return true; + } + + var returnTypeName = GetSimpleName(GetProviderTypeName(returnType)); + return returnTypeName.StartsWith("Paged", StringComparison.Ordinal) || + returnTypeName.EndsWith("Request", StringComparison.Ordinal); + } + + private static void RemoveMethodsFromModelFactory(HashSet namesToRemove) + { + if (namesToRemove.Count == 0) + { + return; + } + + var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value; + _preWriteModelFactory = modelFactory; + _preWriteModelFactoryMethods ??= [.. modelFactory.Methods]; + var methodsToKeep = new List(); + foreach (var method in modelFactory.Methods) + { + if (!namesToRemove.Contains(method.Signature.Name)) + { + methodsToKeep.Add(method); + } + } + + modelFactory.Update(methods: methodsToKeep); + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var declaredNodes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (IsModelFactoryProvider(provider)) + { + continue; + } + + if (publicOnly && !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var name = GetProviderTypeName(provider.Type); + if (!nodes.Contains(name)) + { + continue; + } + + declaredNodes.Add(name); + } + + return declaredNodes; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs new file mode 100644 index 00000000000..e6c4c9fd958 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static void AddSignatureReferences( + HashSet references, + MethodSignatureBase signature, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeAttributes = true, + bool includeAttributeArguments = true) + { + AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddAttributes( + HashSet references, + IReadOnlyList attributes, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeArguments) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); + if (!includeArguments) + { + continue; + } + + foreach (var argument in attribute.Arguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + + foreach (var (_, argument) in attribute.PositionalArguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + } + } + + private static bool IsAttributeNamed(AttributeStatement attribute, string name) + => string.Equals(attribute.Type.Name, name, StringComparison.Ordinal) || + string.Equals(attribute.Type.Name, $"{name}Attribute", StringComparison.Ordinal); + + private static void AddAttributeArgumentReference( + HashSet references, + ValueExpression argument, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType) + { + if (argument is TypeOfExpression typeOf) + { + AddTypeReference(references, typeOf.Type, nodes, serializationProviderNamesByType); + AddMatchingName(references, typeOf.Type.Name, nodes); + } + } + + private static void AddTypeReference( + HashSet references, + CSharpType? type, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType = null) + { + if (type == null) + { + return; + } + + if (type.IsArray) + { + AddTypeReference(references, type.ElementType, nodes, serializationProviderNamesByType); + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) + { + foreach (var serializationProviderName in serializationProviderNames) + { + references.Add(serializationProviderName); + } + } + } + + AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes, serializationProviderNamesByType); + } + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string? GetNamespaceName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? null : fullyQualifiedName.Substring(0, lastDot); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs new file mode 100644 index 00000000000..af0428f30ff --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceMapResult? _latestResult; + private static readonly ConditionalWeakTable, Dictionary> _simpleNameLookupCache = new(); + private static TypeProvider? _preWriteModelFactory; + private static MethodProvider[]? _preWriteModelFactoryMethods; + + public static ProviderReferenceMapResult? LatestResult => _latestResult; + public static bool PreWriteAccessibilityApplied { get; private set; } + + public static bool ShouldWriteProvider(TypeProvider provider) => + _latestResult?.RemoveCandidates.Contains(GetProviderTypeName(provider.Type)) != true; + + public static void ResetPreWriteAccessibility() + { + RestorePreWriteModelFactoryMethods(); + _latestResult = null; + PreWriteAccessibilityApplied = false; + } + + public static void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + PreWriteAccessibilityApplied = false; + if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.KeepAll) + { + return; + } + + // Accessibility has to be adjusted before files are written. Roslyn can remove files + // later, but it cannot safely change provider declarations or model factory signatures. + var (internalizeCandidates, publicizeCandidates) = GetPreWriteAccessibilityCandidates(providers); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (internalizeCandidates.Contains(providerName)) + { + provider.PreserveXmlDocs(); + provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers)); + } + else if (publicizeCandidates.Contains(providerName)) + { + provider.Update(modifiers: MakePublic(provider.DeclarationModifiers)); + } + } + + RemoveMethodsFromModelFactory(GetSimpleNames(internalizeCandidates)); + PreWriteAccessibilityApplied = true; + } + + public static void RestorePreWriteModelFactoryMethods() + { + if (_preWriteModelFactory == null || _preWriteModelFactoryMethods == null) + { + return; + } + + _preWriteModelFactory.Update(methods: _preWriteModelFactoryMethods); + _preWriteModelFactory = null; + _preWriteModelFactoryMethods = null; + } + + public static void Analyze(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + + // Build two graphs from provider metadata: the full implementation graph for removal, + // and the public-surface graph for accessibility decisions. + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customCodeRemovalRoots = GetCustomCodeGeneratedTypeRoots(generatedProviders, graph.Nodes); + var customRemovalRoots = new HashSet(customCodeRemovalRoots, StringComparer.Ordinal); + customRemovalRoots.UnionWith(apiBaselineGeneratedTypeRoots); + customRemovalRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + + // Helper types are rooted after an initial reachability pass so unused infrastructure + // such as change-tracking dictionaries can still be removed when no reachable type needs them. + var generatedDiscriminatorBaseNames = GetGeneratedPersistableModelProxyTypeNames(generatedProviders, publicGraph.Nodes); + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(providers, graph); + var removeCandidates = GetRemovalCandidates( + providers, + generatedProviders, + graph, + customRemovalRoots, + generatedDiscriminatorBaseNames); + + _latestResult = new ProviderReferenceMapResult( + internalizeCandidates, + publicizeCandidates, + removeCandidates); + RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates) GetPreWriteAccessibilityCandidates(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + var generatedDiscriminatorBaseNames = new HashSet(StringComparer.Ordinal); + + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + return (internalizeCandidates, publicizeCandidates); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates, HashSet InternalizeHelperRoots) GetAccessibilityCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + ProviderReferenceGraph publicGraph, + HashSet customPublicRoots, + HashSet customInternalDeclarations, + HashSet generatedInternalDeclarations, + HashSet generatedDiscriminatorBaseNames) + { + var internalizeReferences = CloneReferences(publicGraph.References); + + // Start from public client and custom/public API roots. Anything public-reachable can + // stay public unless it crosses a custom/internal boundary. + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false, includeAdditionalRoots: true, includeUnionVariantRoots: false, includeModelFactorySignatureRoots: true, publicClientRootsOnly: true); + if (ShouldUseUnionVariantFallbackRoots()) + { + AddUnionVariantRoots(internalizeRoots, providers, graph.Nodes); + } + + var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable, generatedDiscriminatorBaseNames); + internalizeRoots.UnionWith(customPublicRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + var publicizeRoots = new HashSet(internalizeRoots, StringComparer.Ordinal); + var publicApiReferences = CloneReferences(publicGraph.References); + var internalizeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); + var customInternalBoundaryNodes = GetCustomInternalBoundaryNodes(publicGraph, customInternalDeclarations); + var publicizeDeclaredNodes = GetPublicizeDeclaredNodes(generatedProviders, graph.Nodes, internalizeDeclaredNodes); + var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedInternalDeclarations); + var publicApiTraversalNodes = GetPublicApiTraversalNodes( + internalizeDeclaredNodes, + publicizeDeclaredNodes, + generatedInternalDeclarations, + generatedImplementationInternalDeclarations); + var publicizeReachable = GetReachableTypes(publicizeRoots, internalizeReferences, publicApiTraversalNodes); + var internalizeCandidates = GetInternalizeCandidates( + internalizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + publicizeRoots, + internalizeReferences); + var publicizeRootExclusions = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: true, + includeModelFactorySignatureRoots: false, + publicClientRootsOnly: true); + var publicizeCandidates = GetPublicizeCandidates( + publicizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + internalizeHelperRoots, + publicizeRootExclusions, + generatedInternalDeclarations, + publicizeRoots, + publicApiReferences, + internalizeReferences, + generatedImplementationInternalDeclarations); + return (internalizeCandidates, publicizeCandidates, internalizeHelperRoots); + } + + private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); + private static bool IsPublic(FieldModifiers modifiers) => modifiers.HasFlag(FieldModifiers.Public); + + private static TypeSignatureModifiers MakeInternal(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Public | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Internal; + + private static TypeSignatureModifiers MakePublic(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Internal | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Public; + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + } +} From 72604858f791fdd81a7567e54165c52961a91333 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Thu, 2 Jul 2026 14:37:02 +0000 Subject: [PATCH 10/19] fix(http-client-csharp): port reference map infrastructure roots Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/ClientProvider.cs | 2 +- .../src/Providers/TypeProvider.cs | 2 +- ...iderReferenceMapAnalyzer.BodyReferences.cs | 167 ++++++++++++++++-- ...derReferenceMapAnalyzer.CustomCodeRoots.cs | 1 - .../ProviderReferenceMapAnalyzer.Helpers.cs | 4 +- ...viderReferenceMapAnalyzer.RootSelection.cs | 2 +- 6 files changed, 162 insertions(+), 16 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index 2c7186b2f25..75fc71973a5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -43,7 +43,7 @@ private record ApiVersionFields(FieldProvider Field, PropertyProvider? Correspon private const string ClientSuffix = "Client"; private readonly FormattableString _publicCtorDescription; private readonly InputClient _inputClient; - protected override bool IsReferenceMapRoot => true; + protected override bool IsClientProvider => true; internal InputClient InputClient => _inputClient; private readonly InputAuth? _inputAuth; private readonly ParameterProvider _endpointParameter; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 4a073025e25..bbaf3e54d24 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -311,7 +311,7 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SignatureDependencyTypes => _signatureDependencyTypes ??= BuildSignatureDependencyTypes(); protected internal virtual IReadOnlyList BuildSignatureDependencyTypes() => []; - protected internal virtual bool IsReferenceMapRoot => false; + protected internal virtual bool IsClientProvider => false; protected internal virtual bool IncludeGeneratedBodyReferences => false; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs index 79fdd39ec80..9fb12dd028d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Linq; using System.Reflection; using Microsoft.TypeSpec.Generator.Expressions; using Microsoft.TypeSpec.Generator.Primitives; @@ -29,15 +30,168 @@ private static void AddGeneratedBodyReferences(IReadOnlyList provi { continue; } + + AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, referencedNames: null); AddProviderBodyDependencyTypes( graph.References[providerName], GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), graph.Nodes); AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); - AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, graph.References[providerName]); + AddProviderInfrastructureReferences(graph.References[providerName], provider, graph.Nodes); + } + } + + private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddMatchingName(references, "ProviderConstants", nodes); + AddMatchingName(references, "TypeFormatters", nodes); + + if (provider.SerializationProviders.Count > 0) + { + AddSerializationExtensionReferences(references, provider, nodes); + } + + if (IsSerializationProvider(provider)) + { + AddMatchingName(references, "Optional", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + AddMatchingName(references, "ModelSerializationExtensions", nodes); + AddSerializationExtensionReferences(references, provider, nodes); + } + + foreach (var method in provider.Methods) + { + AddMethodInfrastructureReferences(references, method, nodes); } } + private static void AddSerializationExtensionReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddSerializationExtensionReferences(references, provider.Type, nodes); + AddSerializationExtensionReferences(references, provider.BaseType, nodes); + foreach (var implementedType in provider.Implements) + { + AddSerializationExtensionReferences(references, implementedType, nodes); + } + + foreach (var property in provider.Properties) + { + AddSerializationExtensionReferences(references, property.Type, nodes); + } + + foreach (var field in provider.Fields) + { + AddSerializationExtensionReferences(references, field.Type, nodes); + } + + foreach (var constructor in provider.Constructors) + { + AddSerializationExtensionReferences(references, constructor.Signature.ReturnType, nodes); + foreach (var parameter in constructor.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + + foreach (var method in provider.Methods) + { + AddSerializationExtensionReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + } + + private static void AddSerializationExtensionReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + AddMatchingName(references, $"{type.Name}Extensions", nodes); + foreach (var argument in type.Arguments) + { + AddSerializationExtensionReferences(references, argument, nodes); + } + } + + private static void AddMethodInfrastructureReferences(HashSet references, MethodProvider method, HashSet nodes) + { + AddReturnTypeInfrastructureReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddRequestContentInfrastructureReferences(references, parameter.Type, nodes); + } + } + + private static void AddReturnTypeInfrastructureReferences(HashSet references, CSharpType? returnType, HashSet nodes) + { + var type = UnwrapTask(returnType); + if (type == null) + { + return; + } + + var typeName = StripGenericArity(type.Name); + if (string.Equals(typeName, "Pageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "PageableWrapper", nodes); + } + else if (string.Equals(typeName, "AsyncPageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "AsyncPageableWrapper", nodes); + } + else if (string.Equals(typeName, "ArmOperation", StringComparison.Ordinal)) + { + AddMatchingNamesWithSimpleNameSuffix(references, "ArmOperation", nodes); + AddMatchingNamesWithSimpleNameSuffix(references, "OperationSource", nodes); + if (type.Arguments.Count > 0) + { + AddMatchingName(references, $"{BuildOperationSourceTypeName(type.Arguments[0])}OperationSource", nodes); + } + } + } + + private static void AddRequestContentInfrastructureReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (string.Equals(type.Name, "RequestContent", StringComparison.Ordinal)) + { + AddMatchingName(references, "BinaryContentHelper", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + } + + foreach (var argument in type.Arguments) + { + AddRequestContentInfrastructureReferences(references, argument, nodes); + } + } + + private static CSharpType? UnwrapTask(CSharpType? type) + { + var typeName = type == null ? null : StripGenericArity(type.Name); + if ((string.Equals(typeName, "Task", StringComparison.Ordinal) || + string.Equals(typeName, "ValueTask", StringComparison.Ordinal)) && + type?.Arguments.Count > 0) + { + return type.Arguments[0]; + } + + return type; + } + + private static string BuildOperationSourceTypeName(CSharpType type) + { + var argumentNames = string.Join("", type.Arguments.Select(BuildOperationSourceTypeName)); + return $"{type.Name}{(argumentNames.Length > 0 ? "Of" : string.Empty)}{argumentNames}"; + } + private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes(TypeProvider provider, HashSet nodes) { var references = new List(); @@ -207,14 +361,7 @@ private static void AddProviderBodyDependencyType( { AddMatchingName(references, dependency.Name, nodes); } - if (nodes.Contains(GetProviderTypeName(dependency))) - { - AddMatchingName(references, $"{dependency.Name}Extensions", nodes); - } - else if (string.Equals(dependency.Name, "RequestContext", StringComparison.Ordinal)) - { - AddMatchingName(references, "RequestContextExtensions", nodes); - } + AddMatchingName(references, $"{dependency.Name}Extensions", nodes); foreach (var argument in dependency.Arguments) { @@ -244,7 +391,7 @@ private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, boo return true; } - return provider.IsReferenceMapRoot || + return provider.IsClientProvider || isSerializationProvider || provider.IncludeGeneratedBodyReferences || provider.HelperDependencyTypes.Count > 0 || diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs index 0c5cc2cfe61..a441d979651 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -78,7 +78,6 @@ customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider private static void AddCustomRequestHeaderExtensionsRoot(HashSet roots, IReadOnlyList providers, HashSet nodes) { - // TODO: Resolve body-level SetDelimited extension calls to PipelineRequestHeadersExtensions so this can be a normal type edge. if (!HasCustomRequestHeaderExtensionsReference(providers)) { return; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs index 87f9e32f473..afd9e777b91 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -30,8 +30,8 @@ private static bool IsKept(CSharpType type, HashSet roots, HashSet - provider.IsReferenceMapRoot && + private static bool IsClientProviderRoot(TypeProvider provider, bool publicOnly) => + provider.IsClientProvider && (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs index 88d265e54c6..26dea9955f8 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs @@ -28,7 +28,7 @@ private static HashSet GetRootNames( foreach (var provider in providers) { var name = GetProviderTypeName(provider.Type); - if (IsReferenceMapRootProvider(provider, publicClientRootsOnly) || + if (IsClientProviderRoot(provider, publicClientRootsOnly) || includeAdditionalRoots && IsAdditionalRootProvider(provider, generator.AdditionalRootTypes, nodes) || includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || includeModelFactory && helperRoots.Contains(name)) From e245cff7476fe92907da369b43c37d224112de34 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 01:17:22 +0000 Subject: [PATCH 11/19] chore: remove csharp generator changelog entry Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../mtg-hybrid-reference-map-2026-07-02-14-20-08.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 .chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md diff --git a/.chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md b/.chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md deleted file mode 100644 index 533726f69bf..00000000000 --- a/.chronus/changes/mtg-hybrid-reference-map-2026-07-02-14-20-08.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -changeKind: fix -packages: - - "@typespec/http-client-csharp" ---- - -Improve generated C# reference-map analysis so provider accessibility and XML documentation stay consistent. From b32f75530b828930be95896c32b909b23e15856c Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 03:20:00 +0000 Subject: [PATCH 12/19] fix: address csharp reference map ci failures Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../emitter/test/Unit/utils.test.ts | 9 ++- .../src/Providers/RestClientProvider.cs | 11 ++- ...ProviderReferenceMapAnalyzer.Candidates.cs | 81 ++++++++++++++++++- .../ProviderReferenceMapAnalyzer.Helpers.cs | 14 +++- .../ProviderReferenceMapAnalyzer.cs | 1 + .../src/Statements/XmlDocStatement.cs | 2 + .../TestGetUnionTypesDescriptions.cs | 6 +- .../ProviderReferenceMapAnalyzerTests.cs | 52 ++++++++++++ .../test/TestHelpers/TestOutputLibrary.cs | 5 ++ .../src/Generated/Models/Thing.cs | 6 +- 10 files changed, 170 insertions(+), 17 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs diff --git a/packages/http-client-csharp/emitter/test/Unit/utils.test.ts b/packages/http-client-csharp/emitter/test/Unit/utils.test.ts index f5d42ff505f..31a90d06d24 100644 --- a/packages/http-client-csharp/emitter/test/Unit/utils.test.ts +++ b/packages/http-client-csharp/emitter/test/Unit/utils.test.ts @@ -12,14 +12,15 @@ import { typeSpecCompile, } from "./utils/test-util.js"; +vi.mock("child_process", () => ({ + spawn: vi.fn(), +})); + describe("execCSharpGenerator tests", () => { let spawnMock: any; let sdkContext: CSharpEmitterContext; beforeEach(async () => { - vi.restoreAllMocks(); - vi.mock("child_process", () => ({ - spawn: vi.fn(), - })); + vi.clearAllMocks(); const runner = await createEmitterTestHost(); const program = await typeSpecCompile(``, runner); const context = createEmitterContext(program); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index ba6967e7d22..0e07cf120b2 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -85,7 +85,7 @@ protected override IReadOnlyList BuildHelperDependencyTypes() { foreach (var parameter in serviceMethod.Operation.Parameters) { - if (IsGeneratedContentTypeMethodParameter(parameter) || + if (IsContentTypeParameter(parameter) || parameter is not InputHeaderParameter and not InputQueryParameter) { continue; @@ -1172,7 +1172,7 @@ internal static List GetMethodParameters( // when one was already published. UpdateParameterNameWithBackCompat(inputParam, inputParam.Name, client.BackCompatProvider, serviceMethod); - ParameterProvider? parameter = IsGeneratedContentTypeMethodParameter(inputParam) && + ParameterProvider? parameter = IsContentTypeParameter(inputParam) && methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest ? CreateContentTypeParameter(inputParam) : ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); @@ -1216,7 +1216,7 @@ methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest break; case ParameterLocation.Query: case ParameterLocation.Header: - if (IsGeneratedContentTypeMethodParameter(inputParam) + if (IsContentTypeParameter(inputParam) && !HasContentTypeBeforeBodyInLastContract(serviceMethod.Name, client.BackCompatProvider)) { sortedParams.Add(contentType++, parameter); @@ -1285,6 +1285,11 @@ private static bool HasLiteralContentTypeHeader(InputOperation operation) return false; } + private static bool IsContentTypeParameter(InputParameter parameter) => + parameter is InputHeaderParameter { IsContentType: true } || + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + private static bool IsGeneratedContentTypeMethodParameter(InputParameter parameter) => parameter is InputMethodParameter { Location: InputRequestLocation.Header } && string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs index ff39f833d3d..da3b9d05cac 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs @@ -75,14 +75,19 @@ private static HashSet GetInternalizeCandidates( HashSet customInternalDeclarations, HashSet customInternalBoundaryNodes, HashSet publicizeRoots, + HashSet nodes, IReadOnlyDictionary> references) { var candidates = new HashSet(StringComparer.Ordinal); foreach (var node in internalizeDeclaredNodes) { + var isNonRootKept = IsKeptName(node, CodeModelGenerator.Instance.NonRootTypes, nodes); if (!publicizeReachable.Contains(node) || customInternalDeclarations.Contains(node) || - customInternalBoundaryNodes.Contains(node) && !publicizeRoots.Contains(node)) + customInternalBoundaryNodes.Contains(node) && (!publicizeRoots.Contains(node) || isNonRootKept) || + isNonRootKept && + references.TryGetValue(node, out var nodeReferences) && + nodeReferences.Overlaps(customInternalDeclarations)) { candidates.Add(node); } @@ -96,8 +101,9 @@ private static HashSet GetInternalizeCandidates( addedCandidate = false; foreach (var node in internalizeDeclaredNodes) { + var isNonRootKept = IsKeptName(node, CodeModelGenerator.Instance.NonRootTypes, nodes); if (candidates.Contains(node) || - publicizeRoots.Contains(node) || + publicizeRoots.Contains(node) && !isNonRootKept || !references.TryGetValue(node, out var nodeReferences) || !nodeReferences.Overlaps(candidates)) { @@ -109,6 +115,10 @@ private static HashSet GetInternalizeCandidates( } } + // Non-root keep entries preserve the declared type/file, but deliberately do not + // root their dependencies. They can still be internalized when needed to avoid + // exposing custom/internal types through a public surface. + RemoveKeptNonRootNames(candidates, nodes, references, customInternalDeclarations, customInternalBoundaryNodes); return candidates; } @@ -175,6 +185,7 @@ private static HashSet GetRemovalCandidates( removeRoots.UnionWith(customRemovalRoots); AddMatchingNamesWithSimpleNameSuffix(removeRoots, "ReferenceType", graph.Nodes); + AddKeptNonRootNames(removeRoots, graph.Nodes); AddCustomCodeExtensionRoots(removeRoots, generatedProviders, graph.Nodes); AddCustomizationBackedExtensionRoots(removeRoots, graph.Nodes); AddCustomRequestHeaderExtensionsRoot(removeRoots, generatedProviders, graph.Nodes); @@ -195,5 +206,71 @@ private static HashSet GetRemovalCandidates( removeDeclaredNodes.ExceptWith(removeReachable); return removeDeclaredNodes; } + + private static void AddKeptNonRootNames(HashSet roots, HashSet nodes) + { + var nonRootTypes = CodeModelGenerator.Instance.NonRootTypes; + foreach (var node in nodes) + { + if (IsKeptName(node, nonRootTypes, nodes)) + { + roots.Add(node); + } + } + } + + private static void RemoveKeptNonRootNames( + HashSet candidates, + HashSet nodes, + IReadOnlyDictionary> references, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes) + { + var nonRootTypes = CodeModelGenerator.Instance.NonRootTypes; + foreach (var node in nodes) + { + if (customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node)) + { + continue; + } + + if (IsKeptName(node, nonRootTypes, nodes) && + !HasCandidateReference(node, candidates, references)) + { + candidates.Remove(node); + } + } + } + + private static bool HasCandidateReference( + string node, + HashSet candidates, + IReadOnlyDictionary> references) + { + if (references.TryGetValue(node, out var nodeReferences)) + { + foreach (var reference in nodeReferences) + { + if (!string.Equals(reference, node, StringComparison.Ordinal) && + candidates.Contains(reference)) + { + return true; + } + } + } + + foreach (var (source, sourceReferences) in references) + { + if (!string.Equals(source, node, StringComparison.Ordinal) && + candidates.Contains(source) && + sourceReferences.Contains(node)) + { + return true; + } + } + + return false; + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs index afd9e777b91..9dd16e07a52 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -14,18 +14,28 @@ internal static partial class ProviderReferenceMapAnalyzer private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) { var providerName = GetProviderTypeName(type); + return IsKeptName(providerName, type.Name, roots, nodes); + } + + private static bool IsKeptName(string providerName, HashSet roots, HashSet nodes) + { + return IsKeptName(providerName, StripGenericArity(GetSimpleName(providerName)), roots, nodes); + } + + private static bool IsKeptName(string providerName, string simpleName, HashSet roots, HashSet nodes) + { if (roots.Contains(providerName) && nodes.Contains(providerName)) { return true; } - if (!roots.Contains(type.Name)) + if (!roots.Contains(simpleName)) { return false; } var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); - return simpleNameLookup.TryGetValue(type.Name, out var matches) && + return simpleNameLookup.TryGetValue(simpleName, out var matches) && matches.Length == 1 && string.Equals(matches[0], providerName, StringComparison.Ordinal); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs index 0e6506d417c..c89186bc345 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs @@ -201,6 +201,7 @@ private static (HashSet InternalizeCandidates, HashSet Publicize customInternalDeclarations, customInternalBoundaryNodes, publicizeRoots, + graph.Nodes, internalizeReferences); var publicizeRootExclusions = GetRootNames( providers, diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs index cfca3c2db7b..110f92af48e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs @@ -217,6 +217,8 @@ public static string EscapeLine(string s) "", "", "", + "", + "", "", "", " -/// . -/// . -/// where TKey is of type , where TValue is of type . +/// bool. +/// int. +/// global::System.Collections.Generic.IDictionary{string,int}. /// 21. /// "test". /// True. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs new file mode 100644 index 00000000000..be99b9ae2f2 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Tests.TestHelpers; +using NUnit.Framework; + +namespace Microsoft.TypeSpec.Generator.Tests.ReferenceMap +{ + public class ProviderReferenceMapAnalyzerTests + { + [SetUp] + public void SetUp() + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + } + + [TearDown] + public void TearDown() + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + } + + [Test] + public void NonRootKeptTypesKeepTheirAccessibility() + { + var context = new TestTypeProvider("SampleContext", TypeSignatureModifiers.Public); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary(context)); + CodeModelGenerator.Instance.AddTypeToKeep(context.Type.FullyQualifiedName, isRoot: false); + + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility([context]); + + Assert.IsTrue(context.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + Assert.IsFalse(context.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)); + } + + [Test] + public void NonRootKeptTypesAreWrittenWithoutRootingOtherTypes() + { + var context = new TestTypeProvider("SampleContext", TypeSignatureModifiers.Public); + var unusedModel = new TestTypeProvider("UnusedModel", TypeSignatureModifiers.Public); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary(context, unusedModel)); + CodeModelGenerator.Instance.AddTypeToKeep(context.Type.FullyQualifiedName, isRoot: false); + + ProviderReferenceMapAnalyzer.Analyze([context, unusedModel]); + + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(context)); + Assert.IsFalse(ProviderReferenceMapAnalyzer.ShouldWriteProvider(unusedModel)); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs index 00119db05dd..7490c2f41e5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs @@ -14,6 +14,11 @@ public TestOutputLibrary(TypeProvider typeProvider) _types = [typeProvider]; } + public TestOutputLibrary(params TypeProvider[] typeProviders) + { + _types = typeProviders; + } + protected override TypeProvider[] BuildTypeProviders() => _types; } } diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs b/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs index 96618c816b8..77af28e9169 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs @@ -100,13 +100,13 @@ internal Thing(string rename, BinaryData requiredUnion, string requiredLiteralSt /// Supported types: /// /// - /// . + /// string. /// /// - /// where T is of type . + /// global::System.Collections.Generic.IList{string}. /// /// - /// . + /// int. /// /// /// From c3be68b235f974e52569c4a835ed438509f7f7ec Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 05:50:49 +0000 Subject: [PATCH 13/19] fix(http-client-csharp): preserve DPG reference map parity Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/RestClientProvider.cs | 53 ++++++++++++-- .../RestClientProviderTests.cs | 73 +++++++++++++++++++ .../TestClient.cs | 20 +++++ ...iderReferenceMapAnalyzer.BodyReferences.cs | 26 +------ ...derReferenceMapAnalyzer.CustomCodeRoots.cs | 5 ++ .../ProviderReferenceMapAnalyzer.Helpers.cs | 20 ++++- .../ProviderReferenceMapAnalyzer.cs | 4 +- .../ProviderReferenceMapAnalyzerTests.cs | 41 +++++++++++ 8 files changed, 208 insertions(+), 34 deletions(-) create mode 100644 packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeOrderPreservedFromLastContractViewWithNamedBody/TestClient.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index 0e07cf120b2..5995ac3d5e5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -80,7 +80,10 @@ protected override FieldProvider[] BuildFields() protected override IReadOnlyList BuildHelperDependencyTypes() { - var dependencies = new List { new ClientUriBuilderDefinition().Type }; + var uriBuilderType = ScmCodeModelGenerator.Instance.TypeFactory.HttpRequestApi.ToExpression().UriBuilderType; + var dependencies = uriBuilderType == typeof(ClientUriBuilderDefinition) + ? new List { new ClientUriBuilderDefinition().Type } + : []; foreach (var serviceMethod in _inputClient.Methods) { foreach (var parameter in serviceMethod.Operation.Parameters) @@ -1217,7 +1220,7 @@ methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest case ParameterLocation.Query: case ParameterLocation.Header: if (IsContentTypeParameter(inputParam) - && !HasContentTypeBeforeBodyInLastContract(serviceMethod.Name, client.BackCompatProvider)) + && !ShouldPreserveContentTypeBeforeBody(methodType, serviceMethod, client.BackCompatProvider)) { sortedParams.Add(contentType++, parameter); } @@ -1259,6 +1262,37 @@ methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest return [.. sortedParams.Values]; } + private static bool ShouldPreserveContentTypeBeforeBody( + ScmMethodKind methodType, + InputServiceMethod serviceMethod, + TypeProvider backCompatProvider) + { + if (HasContentTypeBeforeBodyInLastContract(serviceMethod.Name, backCompatProvider)) + { + return true; + } + + // The baseline contract used for back-compat may come from a released package and can + // lag the generated sources in the repo. For generated convenience methods that expose a + // domain-named body parameter, keep the historic contentType-before-body ordering to avoid + // repo regen churn while protocol methods continue using the normalized "content" body. + return methodType is ScmMethodKind.Convenience && HasNamedBodyParameter(serviceMethod); + } + + private static bool HasNamedBodyParameter(InputServiceMethod serviceMethod) + { + foreach (var parameter in serviceMethod.Parameters) + { + if (parameter.Location == InputRequestLocation.Body && + !string.Equals(parameter.Name, "content", StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + return false; + } + private static ParameterProvider CreateContentTypeParameter(InputParameter inputParam) { var type = new CSharpType(typeof(string), isNullable: !inputParam.IsRequired); @@ -1296,13 +1330,12 @@ private static bool IsGeneratedContentTypeMethodParameter(InputParameter paramet /// /// Checks if the last contract view contains a method matching the given name where - /// a "contentType" parameter appears before the body ("content") parameter. + /// a "contentType" parameter appears before the body parameter. /// If so, we should preserve that ordering for backward compatibility. /// private static bool HasContentTypeBeforeBodyInLastContract(string methodName, TypeProvider backCompatProvider) { const string contentTypeParamName = "contentType"; - const string contentParamName = "content"; var lastContractMethods = backCompatProvider.LastContractView?.Methods; if (lastContractMethods == null || lastContractMethods.Count == 0) @@ -1330,7 +1363,7 @@ private static bool HasContentTypeBeforeBodyInLastContract(string methodName, Ty { contentTypeIndex = i; } - else if (string.Equals(param.Name, contentParamName, StringComparison.OrdinalIgnoreCase)) + else if (IsLastContractBodyParameter(param)) { bodyIndex = i; } @@ -1350,6 +1383,16 @@ private static bool HasContentTypeBeforeBodyInLastContract(string methodName, Ty return false; } + private static bool IsLastContractBodyParameter(ParameterProvider parameter) + { + if (string.Equals(parameter.Name, "content", StringComparison.OrdinalIgnoreCase)) + { + return true; + } + + return parameter.Type.Name is "BinaryContent" or "BinaryData" or "RequestContent"; + } + internal static InputModelType GetSpreadParameterModel(InputParameter inputParam) { if (inputParam.Type is InputModelType model) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs index 7e637e07363..d9f0f9f2733 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs @@ -362,6 +362,79 @@ public async Task ContentTypeAfterBodyInLastContractView() Assert.AreEqual("contentType", methodParameters[2].Name); // contentType after body } + [Test] + public async Task ContentTypeOrderPreservedFromLastContractViewWithNamedBody() + { + var contentTypeEnum = InputFactory.StringEnum("ContentTypeEnum", + [("application/json", "application/json"), ("application/xml", "application/xml")], + isExtensible: true); + var contentTypeHeader = InputFactory.HeaderParameter( + "contentType", + contentTypeEnum, + isRequired: true, + isContentType: true, + serializedName: "Content-Type"); + var bodyParam = InputFactory.BodyParameter("schemaContent", InputPrimitiveType.String, isRequired: true); + var groupNameParam = InputFactory.PathParameter("groupName", InputPrimitiveType.String, isRequired: true); + var pathParam = InputFactory.PathParameter("schemaName", InputPrimitiveType.String, isRequired: true); + var methodGroupNameParam = InputFactory.MethodParameter( + "groupName", + InputPrimitiveType.String, + isRequired: true, + location: InputRequestLocation.Path); + var methodSchemaNameParam = InputFactory.MethodParameter( + "schemaName", + InputPrimitiveType.String, + isRequired: true, + location: InputRequestLocation.Path); + var methodContentTypeParam = InputFactory.MethodParameter( + "contentType", + contentTypeEnum, + isRequired: true, + serializedName: "Content-Type", + location: InputRequestLocation.Header); + var methodBodyParam = InputFactory.MethodParameter( + "schemaContent", + InputPrimitiveType.String, + isRequired: true, + location: InputRequestLocation.Body); + + var operation = InputFactory.Operation( + "RegisterSchema", + parameters: [groupNameParam, pathParam, contentTypeHeader, bodyParam]); + + var serviceMethod = InputFactory.BasicServiceMethod( + "RegisterSchema", + operation, + parameters: [methodGroupNameParam, methodSchemaNameParam, methodContentTypeParam, methodBodyParam]); + + var client = InputFactory.Client("TestClient", methods: [serviceMethod]); + + var generator = await MockHelpers.LoadMockGeneratorAsync( + clients: () => [client], + lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + + var clientProvider = generator.Object.OutputLibrary.TypeProviders.OfType().FirstOrDefault(); + Assert.IsNotNull(clientProvider); + Assert.IsNotNull(clientProvider!.LastContractView); + + var methodParameters = RestClientProvider.GetMethodParameters(serviceMethod, ScmMethodKind.Protocol, clientProvider!); + + Assert.AreEqual(4, methodParameters.Count); + Assert.AreEqual("groupName", methodParameters[0].Name); + Assert.AreEqual("schemaName", methodParameters[1].Name); + Assert.AreEqual("contentType", methodParameters[2].Name); + Assert.AreEqual("content", methodParameters[3].Name); + + var convenienceMethodParameters = RestClientProvider.GetMethodParameters(serviceMethod, ScmMethodKind.Convenience, clientProvider!); + + Assert.AreEqual(4, convenienceMethodParameters.Count); + Assert.AreEqual("groupName", convenienceMethodParameters[0].Name); + Assert.AreEqual("schemaName", convenienceMethodParameters[1].Name); + Assert.AreEqual("contentType", convenienceMethodParameters[2].Name); + Assert.AreEqual("schemaContent", convenienceMethodParameters[3].Name); + } + [Test] public async Task ParameterNamePreservedFromLastContractView() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeOrderPreservedFromLastContractViewWithNamedBody/TestClient.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeOrderPreservedFromLastContractViewWithNamedBody/TestClient.cs new file mode 100644 index 00000000000..e7005ee3202 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeOrderPreservedFromLastContractViewWithNamedBody/TestClient.cs @@ -0,0 +1,20 @@ +#nullable disable + +using System; +using System.ClientModel; +using System.Threading; +using System.Threading.Tasks; + +namespace Sample +{ + public partial class TestClient + { + // This represents a previous contract where contentType appears before a named body parameter. + internal virtual Task RegisterSchemaAsync(string groupName, string schemaName, SchemaContentTypeValues contentType, BinaryData schemaContent, CancellationToken cancellationToken = default) { return null; } + internal virtual ClientResult RegisterSchema(string groupName, string schemaName, SchemaContentTypeValues contentType, BinaryData schemaContent, CancellationToken cancellationToken = default) { return null; } + } + + internal readonly partial struct SchemaContentTypeValues + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs index 9fb12dd028d..f1a7d79a102 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -54,7 +54,6 @@ private static void AddProviderInfrastructureReferences(HashSet referenc if (IsSerializationProvider(provider)) { AddMatchingName(references, "Optional", nodes); - AddMatchingName(references, "Utf8JsonRequestContent", nodes); AddMatchingName(references, "ModelSerializationExtensions", nodes); AddSerializationExtensionReferences(references, provider, nodes); } @@ -120,10 +119,6 @@ private static void AddSerializationExtensionReferences(HashSet referenc private static void AddMethodInfrastructureReferences(HashSet references, MethodProvider method, HashSet nodes) { AddReturnTypeInfrastructureReferences(references, method.Signature.ReturnType, nodes); - foreach (var parameter in method.Signature.Parameters) - { - AddRequestContentInfrastructureReferences(references, parameter.Type, nodes); - } } private static void AddReturnTypeInfrastructureReferences(HashSet references, CSharpType? returnType, HashSet nodes) @@ -154,25 +149,6 @@ private static void AddReturnTypeInfrastructureReferences(HashSet refere } } - private static void AddRequestContentInfrastructureReferences(HashSet references, CSharpType? type, HashSet nodes) - { - if (type == null) - { - return; - } - - if (string.Equals(type.Name, "RequestContent", StringComparison.Ordinal)) - { - AddMatchingName(references, "BinaryContentHelper", nodes); - AddMatchingName(references, "Utf8JsonRequestContent", nodes); - } - - foreach (var argument in type.Arguments) - { - AddRequestContentInfrastructureReferences(references, argument, nodes); - } - } - private static CSharpType? UnwrapTask(CSharpType? type) { var typeName = type == null ? null : StripGenericArity(type.Name); @@ -357,7 +333,7 @@ private static void AddProviderBodyDependencyType( } AddTypeReference(references, dependency, nodes); - if (includeSimpleNameReferences) + if (includeSimpleNameReferences && !string.IsNullOrEmpty(dependency.Namespace)) { AddMatchingName(references, dependency.Name, nodes); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs index ca204b2b5b3..33166f81e8b 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -315,6 +315,11 @@ private static HashSet GetGeneratedInternalTypeDeclarations(IReadOnlyLis private static HashSet GetGeneratedPublicTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Public); + private static HashSet GetGeneratedPublicTypeDeclarationsFromLastContract(IReadOnlyList providers, HashSet generatedTypeNames) + => HasApiBaselineDirectory() + ? new HashSet(StringComparer.Ordinal) + : GetGeneratedPublicTypeDeclarations(providers, generatedTypeNames); + private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessibility( IReadOnlyList providers, HashSet generatedTypeNames, diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs index 9dd16e07a52..8a06ac948c3 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -57,8 +57,24 @@ private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet providers) var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); - var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarationsFromLastContract(generatedProviders, graph.Nodes); customPublicRoots.UnionWith(generatedPublicDeclarations); var customCodeRemovalRoots = GetCustomCodeGeneratedTypeRoots(generatedProviders, graph.Nodes); var customRemovalRoots = new HashSet(customCodeRemovalRoots, StringComparer.Ordinal); @@ -136,7 +136,7 @@ private static (HashSet InternalizeCandidates, HashSet Publicize var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); - var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarationsFromLastContract(generatedProviders, graph.Nodes); customPublicRoots.UnionWith(generatedPublicDeclarations); var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs index be99b9ae2f2..fa075f162b2 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; +using System.Collections.Generic; +using System.Reflection; using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Providers; using Microsoft.TypeSpec.Generator.Tests.TestHelpers; @@ -48,5 +51,43 @@ public void NonRootKeptTypesAreWrittenWithoutRootingOtherTypes() Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(context)); Assert.IsFalse(ProviderReferenceMapAnalyzer.ShouldWriteProvider(unusedModel)); } + + [Test] + public void NamespaceLessCustomCodeBodyDependencyDoesNotRootGeneratedTypeBySimpleName() + { + var customType = new BodyDependencyTestTypeProvider("CustomType", CreateNamedType("Error", string.Empty)); + var generatedError = new TestTypeProvider("Error", TypeSignatureModifiers.Public, ns: "Sample.Models"); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary(customType, generatedError)); + CodeModelGenerator.Instance.AddTypeToKeep(customType.Type.FullyQualifiedName); + + ProviderReferenceMapAnalyzer.Analyze([customType, generatedError]); + + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(customType)); + Assert.IsFalse(ProviderReferenceMapAnalyzer.ShouldWriteProvider(generatedError)); + } + + private sealed class BodyDependencyTestTypeProvider : TestTypeProvider + { + private readonly CSharpType[] _bodyDependencyTypes; + + public BodyDependencyTestTypeProvider(string name, params CSharpType[] bodyDependencyTypes) + : base(name, TypeSignatureModifiers.Public) + { + _bodyDependencyTypes = bodyDependencyTypes; + } + + protected internal override IReadOnlyList BuildBodyDependencyTypes() => _bodyDependencyTypes; + } + + private static CSharpType CreateNamedType(string name, string ns) + { + var constructor = typeof(CSharpType).GetConstructor( + BindingFlags.Instance | BindingFlags.NonPublic, + binder: null, + [typeof(string), typeof(string), typeof(bool), typeof(bool), typeof(CSharpType), typeof(IReadOnlyList), typeof(bool), typeof(bool), typeof(CSharpType), typeof(Type)], + modifiers: null)!; + + return (CSharpType)constructor.Invoke([name, ns, false, false, null, new List(), true, false, null, null]); + } } } From 57e7b6aa908d22eda3e7a5b5e1a6771de68b2980 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 07:33:22 +0000 Subject: [PATCH 14/19] fix(http-client-csharp): preserve internalize write behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/CSharpGen.cs | 62 ++++++++++--------- ...iderReferenceMapAnalyzer.BodyReferences.cs | 6 +- .../ProviderReferenceMapAnalyzer.Helpers.cs | 7 --- .../ProviderReferenceMapAnalyzer.cs | 27 +++++--- .../ProviderReferenceMapAnalyzerTests.cs | 52 ++++++++++++++++ 5 files changed, 106 insertions(+), 48 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index 35914c1f8d4..782a17ae75a 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -93,50 +93,56 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), outputType.ProcessTypeForBackCompatibility(); } - generatedCodeWorkspace.ApplyPreWriteAccessibility(output.TypeProviders); - generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders); - - foreach (var outputType in output.TypeProviders) + try { - if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(outputType)) - { - continue; - } + generatedCodeWorkspace.ApplyPreWriteAccessibility(output.TypeProviders); + generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders); - if (outputType is ModelFactoryProvider && outputType.Methods.Count == 0) + foreach (var outputType in output.TypeProviders) { - continue; - } - - var writer = CodeModelGenerator.Instance.GetWriter(outputType); - generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(outputType)) + { + continue; + } - foreach (var serialization in outputType.SerializationProviders) - { - if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(serialization)) + if (outputType is ModelFactoryProvider && outputType.Methods.Count == 0) { continue; } - writer = CodeModelGenerator.Instance.GetWriter(serialization); + var writer = CodeModelGenerator.Instance.GetWriter(outputType); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + + foreach (var serialization in outputType.SerializationProviders) + { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(serialization)) + { + continue; + } + + writer = CodeModelGenerator.Instance.GetWriter(serialization); + generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); + } } - } - // Add all the generated files to the workspace - await Task.WhenAll(generateFilesTasks); + // Add all the generated files to the workspace + await Task.WhenAll(generateFilesTasks); - ProviderReferenceMapAnalyzer.RestorePreWriteModelFactoryMethods(); + ProviderReferenceMapAnalyzer.RestorePreWriteModelFactoryMethods(); - LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); + LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); - // Delete any old generated files - DeleteDirectory(generatedSourceOutputPath, _filesToKeep); + // Delete any old generated files + DeleteDirectory(generatedSourceOutputPath, _filesToKeep); - LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); + LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); - await generatedCodeWorkspace.PostProcessAsync(); - ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + await generatedCodeWorkspace.PostProcessAsync(); + } + finally + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + } var generatedFiles = new List<(string Name, string Text)>(); await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs index f1a7d79a102..84e50de2b23 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -37,11 +37,11 @@ private static void AddGeneratedBodyReferences(IReadOnlyList provi GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), graph.Nodes); AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); - AddProviderInfrastructureReferences(graph.References[providerName], provider, graph.Nodes); + AddProviderInfrastructureReferences(graph.References[providerName], provider, isSerializationProvider, graph.Nodes); } } - private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, HashSet nodes) + private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, bool isSerializationProvider, HashSet nodes) { AddMatchingName(references, "ProviderConstants", nodes); AddMatchingName(references, "TypeFormatters", nodes); @@ -51,7 +51,7 @@ private static void AddProviderInfrastructureReferences(HashSet referenc AddSerializationExtensionReferences(references, provider, nodes); } - if (IsSerializationProvider(provider)) + if (isSerializationProvider) { AddMatchingName(references, "Optional", nodes); AddMatchingName(references, "ModelSerializationExtensions", nodes); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs index 8a06ac948c3..b99e3ee5bb1 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -332,13 +332,6 @@ private static bool IsRequestHeaderExtensionsDependency(CSharpType? type) return false; } - private static bool IsSerializationProvider(TypeProvider provider) - { - var relativePath = provider.RelativeFilePath.Replace('\\', '/'); - return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || - relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); - } - private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) { if (type == null) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs index c174ed0a827..33730a259b9 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs @@ -111,21 +111,28 @@ public static void Analyze(IReadOnlyList providers) generatedInternalDeclarations, generatedDiscriminatorBaseNames); - // Body-only generated dependencies are needed to avoid deleting helper files, but they do - // not contribute to public API reachability for internalization. - AddGeneratedBodyReferences(providers, graph); - var removeCandidates = GetRemovalCandidates( - providers, - generatedProviders, - graph, - customRemovalRoots, - generatedDiscriminatorBaseNames); + var removeCandidates = new HashSet(StringComparer.Ordinal); + if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize) + { + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(providers, graph); + removeCandidates = GetRemovalCandidates( + providers, + generatedProviders, + graph, + customRemovalRoots, + generatedDiscriminatorBaseNames); + } _latestResult = new ProviderReferenceMapResult( internalizeCandidates, publicizeCandidates, removeCandidates); - RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); + if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize) + { + RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); + } } private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates) GetPreWriteAccessibilityCandidates(IReadOnlyList providers) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs index fa075f162b2..6240720529e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -66,6 +66,43 @@ public void NamespaceLessCustomCodeBodyDependencyDoesNotRootGeneratedTypeBySimpl Assert.IsFalse(ProviderReferenceMapAnalyzer.ShouldWriteProvider(generatedError)); } + [Test] + public void InternalizeModeDoesNotRemoveUnreferencedProviders() + { + var context = new TestTypeProvider("SampleContext", TypeSignatureModifiers.Public); + var unusedModel = new TestTypeProvider("UnusedModel", TypeSignatureModifiers.Public); + MockHelpers.LoadMockGenerator( + createOutputLibrary: () => new TestOutputLibrary(context, unusedModel), + configuration: "{\"unreferenced-types-handling\":\"internalize\"}"); + CodeModelGenerator.Instance.AddTypeToKeep(context.Type.FullyQualifiedName); + + ProviderReferenceMapAnalyzer.Analyze([context, unusedModel]); + + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(context)); + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(unusedModel)); + Assert.IsEmpty(ProviderReferenceMapAnalyzer.LatestResult!.RemoveCandidates); + } + + [Test] + public void SerializationProviderInfrastructureRootsUseSerializationProviderRelationship() + { + var serializationProvider = new TestTypeProvider("SampleModelSerializer", TypeSignatureModifiers.Public); + var model = new ClientRootWithSerializationProvider("SampleModel", serializationProvider); + var optional = new TestTypeProvider("Optional", TypeSignatureModifiers.Public); + var modelSerializationExtensions = new TestTypeProvider("ModelSerializationExtensions", TypeSignatureModifiers.Public); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary( + model, + serializationProvider, + optional, + modelSerializationExtensions)); + + ProviderReferenceMapAnalyzer.Analyze([model, serializationProvider, optional, modelSerializationExtensions]); + + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(serializationProvider)); + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(optional)); + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(modelSerializationExtensions)); + } + private sealed class BodyDependencyTestTypeProvider : TestTypeProvider { private readonly CSharpType[] _bodyDependencyTypes; @@ -79,6 +116,21 @@ public BodyDependencyTestTypeProvider(string name, params CSharpType[] bodyDepen protected internal override IReadOnlyList BuildBodyDependencyTypes() => _bodyDependencyTypes; } + private sealed class ClientRootWithSerializationProvider : TestTypeProvider + { + private readonly TypeProvider[] _serializationProviders; + + public ClientRootWithSerializationProvider(string name, params TypeProvider[] serializationProviders) + : base(name, TypeSignatureModifiers.Public) + { + _serializationProviders = serializationProviders; + } + + protected internal override bool IsClientProvider => true; + + protected override TypeProvider[] BuildSerializationProviders() => _serializationProviders; + } + private static CSharpType CreateNamedType(string name, string ns) { var constructor = typeof(CSharpType).GetConstructor( From 225204314ffb42d87ccca5e84fd33dbde6554993 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 08:03:12 +0000 Subject: [PATCH 15/19] fix(http-client-csharp): avoid custom root name collisions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ...derReferenceMapAnalyzer.CustomCodeRoots.cs | 4 ++-- .../ProviderReferenceMapAnalyzerTests.cs | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs index 33166f81e8b..da2f411ec99 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -104,7 +104,7 @@ private static void AddCustomCodeViewGeneratedTypeRoot(HashSet roots, Ty { if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) { - AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + AddMatchingName(roots, namedTypeSymbolProvider.MetadataName, generatedTypeNames); return; } @@ -284,7 +284,7 @@ private static HashSet GetCustomCodeInternalGeneratedTypeDeclarations(IR if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) { - AddMatchingName(declarations, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + AddMatchingName(declarations, namedTypeSymbolProvider.MetadataName, generatedTypeNames); } else { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs index 6240720529e..d7c05ea08e5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -4,8 +4,10 @@ using System; using System.Collections.Generic; using System.Reflection; +using System.Threading.Tasks; using Microsoft.TypeSpec.Generator.Primitives; using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Tests.Providers.NamedTypeSymbolProviders; using Microsoft.TypeSpec.Generator.Tests.TestHelpers; using NUnit.Framework; @@ -103,6 +105,25 @@ public void SerializationProviderInfrastructureRootsUseSerializationProviderRela Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(modelSerializationExtensions)); } + [Test] + public async Task InternalCustomizationTypeDoesNotInternalizeGeneratedTypeWithSameSimpleName() + { + var customCompilation = CompilationHelper.LoadCompilation( + [new TestTypeProvider("Error", TypeSignatureModifiers.Internal, ns: "Custom.Models")]); + var context = new TestTypeProvider("SampleContext", TypeSignatureModifiers.Public); + var generatedError = new TestTypeProvider("Error", TypeSignatureModifiers.Public, ns: "Generated.Models"); + await MockHelpers.LoadMockGeneratorAsync( + createOutputLibrary: () => new TestOutputLibrary(context, generatedError), + compilation: () => Task.FromResult(customCompilation)); + CodeModelGenerator.Instance.AddTypeToKeep(context.Type.FullyQualifiedName); + CodeModelGenerator.Instance.AddTypeToKeep(generatedError.Type.FullyQualifiedName); + + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility([context, generatedError]); + + Assert.IsTrue(generatedError.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + Assert.IsFalse(generatedError.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)); + } + private sealed class BodyDependencyTestTypeProvider : TestTypeProvider { private readonly CSharpType[] _bodyDependencyTypes; From 5d59307f6ea9cb8871fc74ceb68633249a90bf0f Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 08:42:29 +0000 Subject: [PATCH 16/19] fix(http-client-csharp): restore reviewed compatibility paths Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/RestClientProvider.cs | 86 +++++++++++++++++-- .../src/InputTypes/InputModelTypeUsage.cs | 1 + .../src/LibraryVisitor.cs | 4 + .../src/Providers/TypeProvider.cs | 65 ++++++++++++++ 4 files changed, 149 insertions(+), 7 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index 5995ac3d5e5..90076185b6e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -590,6 +590,18 @@ private static MethodBodyStatement BuildAppendQueryStatement( { if (paramType?.IsCollection != true) { + // A model-typed query parameter marked with `explode` must be expanded into one query + // entry per property (RFC 6570 form explode, e.g. `?field=status&value=active`) rather + // than serialized via the object's ToString (which previously produced the type name). + if (inputQueryParameter.Explode && inputQueryParameter.Type is InputModelType inputModel) + { + var explodeStatement = BuildExplodeModelQueryStatement(uri, inputModel, valueExpression); + if (explodeStatement != null) + { + return explodeStatement; + } + } + var toStringExpression = GetQueryParameterStringExpression(paramType, valueExpression, serializationFormat); return uri.AppendQuery(Literal(inputQueryParameter.SerializedName), toStringExpression, true).Terminate(); } @@ -646,6 +658,70 @@ private static MethodBodyStatement BuildAppendQueryStatement( } } + /// + /// Builds the statements for a model-typed query parameter that uses form-style `explode`. + /// Each (simple) property of the model is emitted as its own query entry using the property's + /// wire name (RFC 6570 form explode, e.g. ?field=status&value=active). + /// Returns null when the model contains a property that is not a simple scalar/enum + /// (e.g. a nested object or a collection), in which case the caller falls back to the default + /// handling. Nested/complex expansion is tracked separately (see issue #11123). + /// + private static MethodBodyStatement? BuildExplodeModelQueryStatement( + ScopedApi uri, + InputModelType inputModel, + ValueExpression valueExpression) + { + var modelProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel); + if (modelProvider is null) + { + return null; + } + + var properties = modelProvider.CanonicalView.Properties; + if (properties.Count == 0) + { + return null; + } + + // Only expand when every property is a simple scalar or enum. Nested objects and + // collections are not defined by RFC 6570 form explode and require a separate design + // decision, so we fall back to the default handling for those. + foreach (var property in properties) + { + if (property.WireInfo is null || + property.Type.IsCollection || + (!property.Type.IsFrameworkType && !property.Type.IsEnum)) + { + return null; + } + } + + var statements = new List(); + foreach (var property in properties) + { + var propertyAccess = valueExpression.Property(property.Name); + var propertyType = property.Type; + + ValueExpression convertedValue = propertyType.IsEnum + ? propertyType.ToSerial(propertyAccess).ConvertToString() + : GetQueryParameterStringExpression(propertyType, propertyAccess, property.SerializationFormat); + + MethodBodyStatement appendStatement = + uri.AppendQuery(Literal(property.WireInfo!.SerializedName), convertedValue, true).Terminate(); + + if (!property.WireInfo.IsRequired || + propertyType.IsNullable || + (propertyType is { IsValueType: false, IsFrameworkType: true } && propertyType.FrameworkType != typeof(string))) + { + appendStatement = BuildQueryOrHeaderOrPathParameterNullCheck(propertyType, propertyAccess, appendStatement); + } + + statements.Add(appendStatement); + } + + return statements; + } + private static IfStatement BuildQueryOrHeaderOrPathParameterNullCheck( CSharpType? parameterType, ValueExpression valueExpression, @@ -884,7 +960,7 @@ private static void AppendLiteralSegment(ScopedApi uri, string literal, List paramMap, InputOperation operation, InputParameter inputParam, out CSharpType? type, out SerializationFormat? serializationFormat, out ValueExpression? valueExpression) { - type = IsGeneratedContentTypeMethodParameter(inputParam) + type = IsContentTypeParameter(inputParam, includeInputHeaderParameter: false) ? null : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); serializationFormat = null; @@ -1319,15 +1395,11 @@ private static bool HasLiteralContentTypeHeader(InputOperation operation) return false; } - private static bool IsContentTypeParameter(InputParameter parameter) => - parameter is InputHeaderParameter { IsContentType: true } || + private static bool IsContentTypeParameter(InputParameter parameter, bool includeInputHeaderParameter = true) => + includeInputHeaderParameter && parameter is InputHeaderParameter { IsContentType: true } || parameter is InputMethodParameter { Location: InputRequestLocation.Header } && string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); - private static bool IsGeneratedContentTypeMethodParameter(InputParameter parameter) => - parameter is InputMethodParameter { Location: InputRequestLocation.Header } && - string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); - /// /// Checks if the last contract view contains a method matching the given name where /// a "contentType" parameter appears before the body parameter. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs index f96ecad090e..52d9c8c0bc8 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs @@ -22,5 +22,6 @@ public enum InputModelTypeUsage LroInitial = 2048, LroPolling = 4096, LroFinalEnvelope = 8192, + External = 16384, } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs index 6df7ccc9758..e39b306ecfb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs @@ -173,6 +173,8 @@ protected internal virtual void VisitLibrary(OutputLibrary library) return constructor; } + internal ConstructorProvider? VisitConstructorProvider(ConstructorProvider constructor) => VisitConstructor(constructor); + /// /// Visits a and returns a possibly modified version of it. /// @@ -306,5 +308,7 @@ protected internal virtual FinallyExpression VisitFinallyExpression(FinallyExpre { return field; } + + internal FieldProvider? VisitFieldProvider(FieldProvider field) => VisitField(field); } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index bbaf3e54d24..6438e7bf333 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -765,10 +765,75 @@ internal void ProcessTypeForBackCompatibility() { _enumValues = updatedEnumValues; } + + // Back-compatibility processing intentionally runs after the library visitor pass so + // that the contract comparison uses the final, post-visitor member signatures (otherwise + // we could incorrectly decide whether a back-compat member is needed). As a result, any + // members synthesized above (e.g. back-compat overloads) have not been visited yet. Run + // only those newly-added members through the visitors now so visitor transforms apply to + // them as well, without re-visiting members that were already visited during the main pass. + if (newMethods != null) + { + newMethods = VisitNewMembers(newMethods, Methods, static (member, visitor) => member.Accept(visitor)); + } + if (newConstructors != null) + { + newConstructors = VisitNewMembers(newConstructors, Constructors, static (member, visitor) => visitor.VisitConstructorProvider(member)); + } + if (newFields != null) + { + newFields = VisitNewMembers(newFields, Fields, static (member, visitor) => visitor.VisitFieldProvider(member)); + } + Update(fields: newFields, methods: newMethods, constructors: newConstructors); } } + // Runs newly-added back-compatibility members through every registered visitor while leaving + // members that were already visited during the main visitor pass untouched. Membership in the + // already-visited set is determined by reference identity against the pre-Update collection. + private static IReadOnlyList VisitNewMembers( + IEnumerable allMembers, + IReadOnlyList alreadyVisited, + Func visit) + where T : class + { + var visitors = CodeModelGenerator.Instance.Visitors; + var materialized = allMembers as IReadOnlyList ?? [.. allMembers]; + if (visitors.Count == 0) + { + return materialized; + } + + var alreadyVisitedSet = new HashSet(alreadyVisited, ReferenceEqualityComparer.Instance); + var result = new List(materialized.Count); + foreach (var member in materialized) + { + if (alreadyVisitedSet.Contains(member)) + { + result.Add(member); + continue; + } + + T? visited = member; + foreach (var visitor in visitors) + { + visited = visit(visited, visitor); + if (visited == null) + { + break; + } + } + + if (visited != null) + { + result.Add(visited); + } + } + + return result; + } + protected internal virtual IReadOnlyList? BuildEnumValuesForBackCompatibility(IReadOnlyList originalEnumValues) => null; From 16b9454d7caa1085eb71b65f1010c24f819a45ea Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 08:53:17 +0000 Subject: [PATCH 17/19] fix(http-client-csharp): address reference map review cleanup Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/Providers/ClientProvider.cs | 1 - .../src/Providers/TypeProvider.cs | 2 -- ...ProviderReferenceMapAnalyzer.BodyReferences.cs | 2 +- .../ProviderReferenceMapAnalyzer.Helpers.cs | 15 ++++++++++++++- .../ProviderReferenceMapResult.cs | 0 .../ProviderReferenceMapAnalyzerTests.cs | 8 +++----- 6 files changed, 18 insertions(+), 10 deletions(-) rename packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/{PostProcessing => ReferenceMap}/ProviderReferenceMapResult.cs (100%) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index 55fee64b30f..6aa3cb10d18 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -43,7 +43,6 @@ private record ApiVersionFields(FieldProvider Field, PropertyProvider? Correspon private const string ClientSuffix = "Client"; private readonly FormattableString _publicCtorDescription; private readonly InputClient _inputClient; - protected override bool IsClientProvider => true; internal InputClient InputClient => _inputClient; private readonly InputAuth? _inputAuth; private readonly ParameterProvider _endpointParameter; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 6438e7bf333..6e8d8c841e5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -311,8 +311,6 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SignatureDependencyTypes => _signatureDependencyTypes ??= BuildSignatureDependencyTypes(); protected internal virtual IReadOnlyList BuildSignatureDependencyTypes() => []; - protected internal virtual bool IsClientProvider => false; - protected internal virtual bool IncludeGeneratedBodyReferences => false; private IReadOnlyList? _attributes; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs index 84e50de2b23..dee4600d12e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -367,7 +367,7 @@ private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, boo return true; } - return provider.IsClientProvider || + return IsClientProvider(provider) || isSerializationProvider || provider.IncludeGeneratedBodyReferences || provider.HelperDependencyTypes.Count > 0 || diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs index b99e3ee5bb1..b8d38c34e6b 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -41,9 +41,22 @@ private static bool IsKeptName(string providerName, string simpleName, HashSet - provider.IsClientProvider && + IsClientProvider(provider) && (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + private static bool IsClientProvider(TypeProvider provider) + { + for (var type = provider.GetType(); type != null && type != typeof(TypeProvider); type = type.BaseType) + { + if (string.Equals(type.Name, "ClientProvider", StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) { if (provider.DeclaringTypeProvider != null || !IsKept(provider.Type, roots, nodes)) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapResult.cs similarity index 100% rename from packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs rename to packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapResult.cs diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs index d7c05ea08e5..531632a971f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -89,7 +89,7 @@ public void InternalizeModeDoesNotRemoveUnreferencedProviders() public void SerializationProviderInfrastructureRootsUseSerializationProviderRelationship() { var serializationProvider = new TestTypeProvider("SampleModelSerializer", TypeSignatureModifiers.Public); - var model = new ClientRootWithSerializationProvider("SampleModel", serializationProvider); + var model = new ClientProvider("SampleModel", serializationProvider); var optional = new TestTypeProvider("Optional", TypeSignatureModifiers.Public); var modelSerializationExtensions = new TestTypeProvider("ModelSerializationExtensions", TypeSignatureModifiers.Public); MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary( @@ -137,18 +137,16 @@ public BodyDependencyTestTypeProvider(string name, params CSharpType[] bodyDepen protected internal override IReadOnlyList BuildBodyDependencyTypes() => _bodyDependencyTypes; } - private sealed class ClientRootWithSerializationProvider : TestTypeProvider + private sealed class ClientProvider : TestTypeProvider { private readonly TypeProvider[] _serializationProviders; - public ClientRootWithSerializationProvider(string name, params TypeProvider[] serializationProviders) + public ClientProvider(string name, params TypeProvider[] serializationProviders) : base(name, TypeSignatureModifiers.Public) { _serializationProviders = serializationProviders; } - protected internal override bool IsClientProvider => true; - protected override TypeProvider[] BuildSerializationProviders() => _serializationProviders; } From 02ff7cd645e77933653b465305719eeece368a38 Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 09:04:02 +0000 Subject: [PATCH 18/19] fix(http-client-csharp): keep custom signature references public Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- ...iderReferenceMapAnalyzer.BodyReferences.cs | 14 +++++-- ...derReferenceMapAnalyzer.CustomCodeRoots.cs | 32 +++++++++++++-- .../ProviderReferenceMapAnalyzer.Helpers.cs | 15 +++++++ .../ProviderReferenceMapAnalyzerTests.cs | 39 +++++++++++++++++++ 4 files changed, 93 insertions(+), 7 deletions(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs index dee4600d12e..b5a9f8ffcdb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -313,11 +313,12 @@ private static void AddProviderBodyDependencyTypes( HashSet references, IReadOnlyList dependencies, HashSet nodes, - bool includeSimpleNameReferences = false) + bool includeSimpleNameReferences = false, + bool includeUnqualifiedSimpleNameReferences = false) { foreach (var dependency in dependencies) { - AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences); + AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences, includeUnqualifiedSimpleNameReferences); } } @@ -325,7 +326,8 @@ private static void AddProviderBodyDependencyType( HashSet references, CSharpType? dependency, HashSet nodes, - bool includeSimpleNameReferences) + bool includeSimpleNameReferences, + bool includeUnqualifiedSimpleNameReferences) { if (dependency == null) { @@ -337,11 +339,15 @@ private static void AddProviderBodyDependencyType( { AddMatchingName(references, dependency.Name, nodes); } + else if (includeUnqualifiedSimpleNameReferences && string.IsNullOrEmpty(dependency.Namespace)) + { + AddUnambiguousMatchingName(references, dependency.Name, nodes); + } AddMatchingName(references, $"{dependency.Name}Extensions", nodes); foreach (var argument in dependency.Arguments) { - AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences); + AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences, includeUnqualifiedSimpleNameReferences); } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs index da2f411ec99..f3d98b0012d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -104,7 +104,7 @@ private static void AddCustomCodeViewGeneratedTypeRoot(HashSet roots, Ty { if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) { - AddMatchingName(roots, namedTypeSymbolProvider.MetadataName, generatedTypeNames); + AddExactMetadataNameMatch(roots, namedTypeSymbolProvider.MetadataName, generatedTypeNames); return; } @@ -138,7 +138,7 @@ private static void AddCustomizationBackedExtensionRoots(HashSet roots, private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames, bool publicOnly) { AddTypeReference(roots, customCodeView.BaseType, generatedTypeNames); - AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true, includeUnqualifiedSimpleNameReferences: true); if (!publicOnly) { AddProviderBodyDependencyTypes(roots, customCodeView.BodyDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); @@ -284,7 +284,7 @@ private static HashSet GetCustomCodeInternalGeneratedTypeDeclarations(IR if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) { - AddMatchingName(declarations, namedTypeSymbolProvider.MetadataName, generatedTypeNames); + AddExactMetadataNameMatch(declarations, namedTypeSymbolProvider.MetadataName, generatedTypeNames); } else { @@ -295,6 +295,32 @@ private static HashSet GetCustomCodeInternalGeneratedTypeDeclarations(IR return declarations; } + private static void AddExactMetadataNameMatch(HashSet target, string metadataName, HashSet generatedTypeNames) + { + var normalizedName = NormalizeMetadataTypeName(metadataName); + if (!string.IsNullOrEmpty(normalizedName) && generatedTypeNames.Contains(normalizedName)) + { + target.Add(normalizedName); + } + } + + private static string NormalizeMetadataTypeName(string metadataName) + { + var arrayIndex = metadataName.IndexOf('[', StringComparison.Ordinal); + if (arrayIndex > 0) + { + metadataName = metadataName.Substring(0, arrayIndex); + } + + var genericIndex = metadataName.IndexOf('<', StringComparison.Ordinal); + if (genericIndex > 0) + { + metadataName = metadataName.Substring(0, genericIndex); + } + + return metadataName; + } + private static HashSet GetGeneratedPersistableModelProxyTypeNames(IReadOnlyList providers, HashSet generatedTypeNames) { var proxyTypes = new HashSet(StringComparer.Ordinal); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs index b8d38c34e6b..523688c6ea4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -417,6 +417,21 @@ private static void AddMatchingName(HashSet target, string name, HashSet } } + private static void AddUnambiguousMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + if (simpleNameLookup.TryGetValue(name, out var matches) && matches.Length == 1) + { + target.Add(matches[0]); + } + } + private static void AddMatchingNamesWithSimpleNameSuffix(HashSet target, string suffix, HashSet nodes) { foreach (var node in nodes) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs index 531632a971f..30801b46c7c 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -124,6 +124,19 @@ await MockHelpers.LoadMockGeneratorAsync( Assert.IsFalse(generatedError.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)); } + [Test] + public void PublicCustomCodeArraySignatureKeepsGeneratedTypePublic() + { + var customCodeView = new SignatureDependencyTestTypeProvider("PublicCustomApi", TypeSignatureModifiers.Public, CreateNamedType("GeneratedModel", string.Empty)); + var generatedModel = new CustomizableTestTypeProvider("GeneratedModel", TypeSignatureModifiers.Public, customCodeView, ns: "Generated.Models"); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary(generatedModel)); + + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility([generatedModel]); + + Assert.IsTrue(generatedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + Assert.IsFalse(generatedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)); + } + private sealed class BodyDependencyTestTypeProvider : TestTypeProvider { private readonly CSharpType[] _bodyDependencyTypes; @@ -137,6 +150,32 @@ public BodyDependencyTestTypeProvider(string name, params CSharpType[] bodyDepen protected internal override IReadOnlyList BuildBodyDependencyTypes() => _bodyDependencyTypes; } + private sealed class SignatureDependencyTestTypeProvider : TestTypeProvider + { + private readonly CSharpType[] _signatureDependencyTypes; + + public SignatureDependencyTestTypeProvider(string name, TypeSignatureModifiers declarationModifiers, params CSharpType[] signatureDependencyTypes) + : base(name, declarationModifiers) + { + _signatureDependencyTypes = signatureDependencyTypes; + } + + protected internal override IReadOnlyList BuildSignatureDependencyTypes() => _signatureDependencyTypes; + } + + private sealed class CustomizableTestTypeProvider : TestTypeProvider + { + private readonly TypeProvider _customCodeView; + + public CustomizableTestTypeProvider(string name, TypeSignatureModifiers declarationModifiers, TypeProvider customCodeView, string ns) + : base(name, declarationModifiers, ns: ns) + { + _customCodeView = customCodeView; + } + + private protected override TypeProvider? BuildCustomCodeView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => _customCodeView; + } + private sealed class ClientProvider : TestTypeProvider { private readonly TypeProvider[] _serializationProviders; From 38fa84983ae4a0968dc73d51141436c306bfc67c Mon Sep 17 00:00:00 2001 From: Wei Hu Date: Fri, 3 Jul 2026 11:38:47 +0000 Subject: [PATCH 19/19] Fix custom body reference pruning Preserve unqualified generated body dependencies referenced from customization code so provider reference-map pruning does not remove helper types that are only used in custom method bodies. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ProviderReferenceMapAnalyzer.CustomCodeRoots.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs index f3d98b0012d..f1f18bdba1f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -141,7 +141,7 @@ private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider c AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true, includeUnqualifiedSimpleNameReferences: true); if (!publicOnly) { - AddProviderBodyDependencyTypes(roots, customCodeView.BodyDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + AddProviderBodyDependencyTypes(roots, customCodeView.BodyDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true, includeUnqualifiedSimpleNameReferences: true); AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); }