diff --git a/packages/http-client-csharp/emitter/test/Unit/utils.test.ts b/packages/http-client-csharp/emitter/test/Unit/utils.test.ts index f5d42ff505f..31a90d06d24 100644 --- a/packages/http-client-csharp/emitter/test/Unit/utils.test.ts +++ b/packages/http-client-csharp/emitter/test/Unit/utils.test.ts @@ -12,14 +12,15 @@ import { typeSpecCompile, } from "./utils/test-util.js"; +vi.mock("child_process", () => ({ + spawn: vi.fn(), +})); + describe("execCSharpGenerator tests", () => { let spawnMock: any; let sdkContext: CSharpEmitterContext; beforeEach(async () => { - vi.restoreAllMocks(); - vi.mock("child_process", () => ({ - spawn: vi.fn(), - })); + vi.clearAllMocks(); const runner = await createEmitterTestHost(); const program = await typeSpecCompile(``, runner); const context = createEmitterContext(program); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs index f8ee8744e32..55fee64b30f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientProvider.cs @@ -43,6 +43,7 @@ private record ApiVersionFields(FieldProvider Field, PropertyProvider? Correspon private const string ClientSuffix = "Client"; private readonly FormattableString _publicCtorDescription; private readonly InputClient _inputClient; + protected override bool IsClientProvider => true; internal InputClient InputClient => _inputClient; private readonly InputAuth? _inputAuth; private readonly ParameterProvider _endpointParameter; @@ -426,6 +427,82 @@ private IReadOnlyList GetClientParameters() protected override string BuildName() => _inputClient.IsExactName ? _inputClient.Name : _inputClient.Name.ToIdentifierName(); + protected override IReadOnlyList BuildHelperDependencyTypes() + { + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements != null) + { + return [new CancellationTokenExtensionsDefinition().Type, new ClientPipelineExtensionsDefinition().Type]; + } + } + + return []; + } + + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List(); + foreach (var method in Methods.OfType()) + { + if (method.BodyStatements == null) + { + continue; + } + + if (method.CollectionDefinition != null) + { + dependencies.Add(method.CollectionDefinition.Type); + } + + if (method.ServiceMethod == null) + { + continue; + } + + AddInputTypeDependency(dependencies, method.ServiceMethod.Response.Type); + AddInputTypeDependency(dependencies, method.ServiceMethod.Exception?.Type); + foreach (var parameter in method.ServiceMethod.Parameters) + { + if (IsContentTypeParameter(parameter)) + { + continue; + } + + AddInputTypeDependency(dependencies, parameter.Type); + } + + foreach (var parameter in method.ServiceMethod.Operation.Parameters) + { + if (IsContentTypeParameter(parameter)) + { + continue; + } + + AddInputTypeDependency(dependencies, parameter.Type); + } + + // Operation responses are input metadata. The generated method signature and body + // dependencies above capture the response types that are actually used. + } + + return dependencies; + } + + private static bool IsContentTypeParameter(InputParameter parameter) => + parameter is InputHeaderParameter { IsContentType: true } || + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + + private static void AddInputTypeDependency(List dependencies, InputType? inputType) + { + var type = inputType == null ? null : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputType); + if (type != null) + { + dependencies.Add(type); + } + } + protected override FieldProvider[] BuildFields() { List fields = [EndpointField]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs index 92e0cf23a67..4dc5c283cc7 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/ClientUriBuilderDefinition.cs @@ -28,6 +28,7 @@ internal sealed class ClientUriBuilderDefinition : TypeProvider private readonly FieldProvider _uriBuilderField; private readonly FieldProvider _pathAndQueryField; private readonly FieldProvider _pathLengthField; + protected override bool IncludeGeneratedBodyReferences => true; private PropertyProvider? _uriBuilderProperty; private PropertyProvider UriBuilderProperty => _uriBuilderProperty ??= new( diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs index 62f28d86c8d..b734d9a6529 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/CollectionResultDefinition.cs @@ -217,6 +217,22 @@ private bool HasPagingOperationNameCollision(string operationName) protected override TypeSignatureModifiers BuildDeclarationModifiers() => TypeSignatureModifiers.Internal | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class; + protected override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new List { Client.Type, ResponseModelType, NextPagePropertyType }; + if (ItemModelType != null) + { + dependencies.Add(ItemModelType); + } + + foreach (var field in RequestFields) + { + dependencies.Add(field.Type); + } + + return dependencies; + } + protected override FieldProvider[] BuildFields() => [ClientField, .. RequestFields]; protected override CSharpType[] BuildImplements() => diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs index 5d76b8f44a1..2b50fff372e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs @@ -67,7 +67,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Internal | MethodSignatureModifiers.Virtual; - if (_shouldOverrideXmlMethods) + if (_shouldOverrideMethods) { modifiers = MethodSignatureModifiers.Internal | MethodSignatureModifiers.Override; } @@ -81,7 +81,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() { - var categorizedProperties = _shouldOverrideXmlMethods + var categorizedProperties = _shouldOverrideMethods ? CategorizedXmlProperties : AllCategorizedXmlProperties; var statements = new List @@ -90,7 +90,7 @@ private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() MethodBodyStatement.EmptyLine }; - if (_shouldOverrideXmlMethods) + if (_shouldOverrideMethods) { statements.Add(Base.Invoke(XmlModelWriteCoreMethodName, _xmlWriterParameter, _serializationOptionsParameter).Terminate()); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 8202a2af405..97532886c40 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -53,16 +53,10 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly ScopedApi _mrwOptionsParameterSnippet; private readonly ScopedApi _jsonElementParameterSnippet; private readonly ScopedApi _isNotEqualToWireConditionSnippet; - // These interface types depend on _model.Type. Build them lazily so we do not cache a - // CSharpType before delayed base model resolution has updated the model's inheritance. - private CSharpType? _jsonModelTInterfaceValue; - private CSharpType _jsonModelTInterface => _jsonModelTInterfaceValue ??= new CSharpType(typeof(IJsonModel<>), SerializationInterfaceType.Type); - private CSharpType? _jsonModelObjectInterface; - private CSharpType? JsonModelObjectInterface => _isStruct ? _jsonModelObjectInterface ??= (CSharpType)typeof(IJsonModel) : null; - private CSharpType? _persistableModelTInterfaceValue; - private CSharpType _persistableModelTInterface => _persistableModelTInterfaceValue ??= new CSharpType(typeof(IPersistableModel<>), SerializationInterfaceType.Type); - private CSharpType? _persistableModelObjectInterface; - private CSharpType? PersistableModelObjectInterface => _isStruct ? _persistableModelObjectInterface ??= (CSharpType)typeof(IPersistableModel) : null; + private readonly CSharpType _jsonModelTInterface; + private readonly CSharpType? _jsonModelObjectInterface; + private readonly CSharpType _persistableModelTInterface; + private readonly CSharpType? _persistableModelObjectInterface; private readonly ModelProvider _model; private readonly InputModelType _inputModel; private readonly FieldProvider? _rawDataField; @@ -73,20 +67,10 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly bool _supportsXml; private ConstructorProvider? _serializationConstructor; // Flag to determine if the model should override the serialization methods - private bool? _shouldOverrideMethods; - private bool ShouldOverrideMethods => _shouldOverrideMethods ??= _model.BaseModelProvider != null && !_isStruct; - private bool? _shouldSkipSerializationMethodOverrides; - private bool ShouldSkipSerializationMethodOverrides => _shouldSkipSerializationMethodOverrides ??= ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); - private readonly bool _shouldOverrideXmlMethods; + private readonly bool _shouldOverrideMethods; + private readonly bool _shouldSkipDerivedSerializationMethodOverrides; private readonly Lazy _additionalProperties; - // Unknown discriminator models use their base model as the serialization interface type. - // This can also touch model.Type, so defer it until serialization method/interface emission. - private TypeProvider SerializationInterfaceType => _serializationInterfaceType ??= _inputModel.IsUnknownDiscriminatorModel - ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel!)! - : _model; - private TypeProvider? _serializationInterfaceType; - private CSharpType RootType => _rootType ??= GetRootModelType(); private CSharpType? _rootType; @@ -100,10 +84,17 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m _isStruct = _model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Struct); _supportsXml = inputModel.Usage.HasFlag(InputModelTypeUsage.Xml); _supportsJson = inputModel.Usage.HasFlag(InputModelTypeUsage.Json) || !_supportsXml; - _shouldOverrideXmlMethods = _model.BaseModelProvider != null && !_isStruct; + // Initialize the serialization interfaces + var interfaceType = inputModel.IsUnknownDiscriminatorModel ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel.BaseModel!)! : _model; + _jsonModelTInterface = new CSharpType(typeof(IJsonModel<>), interfaceType.Type); + _jsonModelObjectInterface = _isStruct ? (CSharpType)typeof(IJsonModel) : null; + _persistableModelTInterface = new CSharpType(typeof(IPersistableModel<>), interfaceType.Type); + _persistableModelObjectInterface = _isStruct ? (CSharpType)typeof(IPersistableModel) : null; _rawDataField = _model.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); _additionalBinaryDataProperty = new(GetAdditionalBinaryDataPropertiesProp); _additionalProperties = new(() => [.. _model.Properties.Where(p => p.IsAdditionalProperties)]); + _shouldOverrideMethods = _model.BaseModelProvider != null && !_isStruct; + _shouldSkipDerivedSerializationMethodOverrides = ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); _utf8JsonWriterSnippet = _utf8JsonWriterParameter.As(); _mrwOptionsParameterSnippet = _serializationOptionsParameter.As(); _jsonElementParameterSnippet = _jsonElementDeserializationParam.As(); @@ -126,6 +117,10 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m protected override CSharpType? BuildBaseType() => _model.BaseType; + protected override IReadOnlyList BuildHelperDependencyTypes() => _rawDataField != null || _additionalProperties.Value.Length > 0 + ? [ScmCodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() { if (_model.CanonicalView.Properties.Any(p => ScmModelProvider.IsFileBinaryContentType(p.Type))) @@ -438,19 +433,17 @@ protected override CSharpType[] BuildImplements() if (_supportsJson) { interfaces.Add(_jsonModelTInterface); - var jsonModelObjectInterface = JsonModelObjectInterface; - if (jsonModelObjectInterface != null) + if (_jsonModelObjectInterface != null) { - interfaces.Add(jsonModelObjectInterface); + interfaces.Add(_jsonModelObjectInterface); } } else if (_supportsXml) { interfaces.Add(_persistableModelTInterface); - var persistableModelObjectInterface = PersistableModelObjectInterface; - if (persistableModelObjectInterface != null) + if (_persistableModelObjectInterface != null) { - interfaces.Add(persistableModelObjectInterface); + interfaces.Add(_persistableModelObjectInterface); } } @@ -480,7 +473,7 @@ internal MethodProvider BuildJsonModelWriteMethodObjectDeclaration() var castToT = This.CastTo(_jsonModelTInterface); return new MethodProvider ( - new MethodSignature(nameof(IJsonModel.Write), null, MethodSignatureModifiers.None, null, null, [_utf8JsonWriterParameter, _serializationOptionsParameter], ExplicitInterface: JsonModelObjectInterface), + new MethodSignature(nameof(IJsonModel.Write), null, MethodSignatureModifiers.None, null, null, [_utf8JsonWriterParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelObjectInterface), castToT.Invoke(nameof(IJsonModel.Write), [_utf8JsonWriterParameter, _serializationOptionsParameter]), this ); @@ -495,7 +488,7 @@ internal MethodProvider BuildJsonModelCreateMethodObjectDeclaration() var castToT = This.CastTo(_jsonModelTInterface); return new MethodProvider ( - new MethodSignature(nameof(IJsonModel.Create), null, MethodSignatureModifiers.None, typeof(object), null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: JsonModelObjectInterface), + new MethodSignature(nameof(IJsonModel.Create), null, MethodSignatureModifiers.None, typeof(object), null, [_utf8JsonReaderParameter, _serializationOptionsParameter], ExplicitInterface: _jsonModelObjectInterface), castToT.Invoke(nameof(IJsonModel.Create), [_utf8JsonReaderParameter.AsArgument(), _serializationOptionsParameter]), this ); @@ -511,7 +504,7 @@ internal MethodProvider BuildPersistableModelWriteMethodObjectDeclaration() var returnType = typeof(BinaryData); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.Write), null, MethodSignatureModifiers.None, returnType, null, [_serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.Write), null, MethodSignatureModifiers.None, returnType, null, [_serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.Write), [_serializationOptionsParameter]), this ); @@ -527,7 +520,7 @@ internal MethodProvider BuildPersistableModelCreateMethodObjectDeclaration() var returnType = typeof(object); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.Create), null, MethodSignatureModifiers.None, returnType, null, [_dataParameter, _serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.Create), null, MethodSignatureModifiers.None, returnType, null, [_dataParameter, _serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.Create), [_dataParameter, _serializationOptionsParameter]), this ); @@ -541,7 +534,7 @@ internal MethodProvider BuildJsonModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods) + if (_shouldOverrideMethods) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -563,7 +556,7 @@ internal MethodProvider BuildPersistableModelWriteCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -587,7 +580,7 @@ internal MethodProvider BuildPersistableModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -635,7 +628,7 @@ internal MethodProvider BuildJsonModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) + if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -807,7 +800,7 @@ internal MethodProvider BuildPersistableModelGetFormatFromOptionsObjectDeclarati // string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => ((IPersistableModel)this).GetFormatFromOptions(options); return new MethodProvider ( - new MethodSignature(nameof(IPersistableModel.GetFormatFromOptions), null, MethodSignatureModifiers.None, typeof(string), null, [_serializationOptionsParameter], ExplicitInterface: PersistableModelObjectInterface), + new MethodSignature(nameof(IPersistableModel.GetFormatFromOptions), null, MethodSignatureModifiers.None, typeof(string), null, [_serializationOptionsParameter], ExplicitInterface: _persistableModelObjectInterface), castToT.Invoke(nameof(IPersistableModel.GetFormatFromOptions), [_serializationOptionsParameter]), this ); @@ -1066,7 +1059,7 @@ private MethodBodyStatement[] BuildPersistableModelCreateCoreMethodBody() private MethodBodyStatement CallBaseJsonModelWriteCore(bool isDynamicModelWithNonDynamicBase) { // base.() - bool callBaseWriteMethod = ShouldOverrideMethods + bool callBaseWriteMethod = _shouldOverrideMethods && (_jsonPatchProperty is null || !isDynamicModelWithNonDynamicBase); return callBaseWriteMethod ? Base.Invoke(JsonModelWriteCoreMethodName, [_utf8JsonWriterParameter, _serializationOptionsParameter]).Terminate() diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs index 88cb97b16e7..696d654dd9c 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MultipartFormDataSerializationDefinition.cs @@ -51,6 +51,12 @@ protected override string BuildRelativeFilePath() return Path.Combine("src", "Generated", "Models", $"{Name}.Serialization.Multipart.cs"); } + protected override IReadOnlyList BuildHelperDependencyTypes() => _model.Properties.Any( + prop => prop.WireInfo != null && !prop.WireInfo.IsRequired && + (prop.Type is { IsCollection: true, IsReadOnlyMemory: false } || prop.Type.IsDictionary)) + ? [ScmCodeModelGenerator.Instance.TypeFactory.OptionalType] + : []; + protected override SuppressionStatement[] BuildDisabledFileWarnings() => [new SuppressionStatement(null, Literal(ScmModelProvider.FileBinaryContentDiagnosticId), ScmModelProvider.ScmEvaluationTypeSuppressionJustification)]; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index f00d78571d0..0e07cf120b2 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -78,6 +78,44 @@ protected override FieldProvider[] BuildFields() return [.. pipelineMessage20xClassifiersFields]; } + protected override IReadOnlyList BuildHelperDependencyTypes() + { + var dependencies = new List { new ClientUriBuilderDefinition().Type }; + foreach (var serviceMethod in _inputClient.Methods) + { + foreach (var parameter in serviceMethod.Operation.Parameters) + { + if (IsContentTypeParameter(parameter) || + parameter is not InputHeaderParameter and not InputQueryParameter) + { + continue; + } + + var type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(parameter.Type); + if (type?.IsDictionary == true) + { + AddDependency(dependencies, ScmCodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType); + } + else if (type?.IsCollection == true) + { + AddDependency(dependencies, ScmCodeModelGenerator.Instance.TypeFactory.ListInitializationType); + } + } + } + + return dependencies; + } + + private static void AddDependency(List dependencies, CSharpType dependency) + { + if (!dependencies.Any(existing => + existing.Name == dependency.Name && + existing.Namespace == dependency.Namespace)) + { + dependencies.Add(dependency); + } + } + protected override ScmMethodProvider[] BuildMethods() { List methods = new List(); @@ -549,18 +587,6 @@ private static MethodBodyStatement BuildAppendQueryStatement( { if (paramType?.IsCollection != true) { - // A model-typed query parameter marked with `explode` must be expanded into one query - // entry per property (RFC 6570 form explode, e.g. `?field=status&value=active`) rather - // than serialized via the object's ToString (which previously produced the type name). - if (inputQueryParameter.Explode && inputQueryParameter.Type is InputModelType inputModel) - { - var explodeStatement = BuildExplodeModelQueryStatement(uri, inputModel, valueExpression); - if (explodeStatement != null) - { - return explodeStatement; - } - } - var toStringExpression = GetQueryParameterStringExpression(paramType, valueExpression, serializationFormat); return uri.AppendQuery(Literal(inputQueryParameter.SerializedName), toStringExpression, true).Terminate(); } @@ -617,70 +643,6 @@ private static MethodBodyStatement BuildAppendQueryStatement( } } - /// - /// Builds the statements for a model-typed query parameter that uses form-style `explode`. - /// Each (simple) property of the model is emitted as its own query entry using the property's - /// wire name (RFC 6570 form explode, e.g. ?field=status&value=active). - /// Returns null when the model contains a property that is not a simple scalar/enum - /// (e.g. a nested object or a collection), in which case the caller falls back to the default - /// handling. Nested/complex expansion is tracked separately (see issue #11123). - /// - private static MethodBodyStatement? BuildExplodeModelQueryStatement( - ScopedApi uri, - InputModelType inputModel, - ValueExpression valueExpression) - { - var modelProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel); - if (modelProvider is null) - { - return null; - } - - var properties = modelProvider.CanonicalView.Properties; - if (properties.Count == 0) - { - return null; - } - - // Only expand when every property is a simple scalar or enum. Nested objects and - // collections are not defined by RFC 6570 form explode and require a separate design - // decision, so we fall back to the default handling for those. - foreach (var property in properties) - { - if (property.WireInfo is null || - property.Type.IsCollection || - (!property.Type.IsFrameworkType && !property.Type.IsEnum)) - { - return null; - } - } - - var statements = new List(); - foreach (var property in properties) - { - var propertyAccess = valueExpression.Property(property.Name); - var propertyType = property.Type; - - ValueExpression convertedValue = propertyType.IsEnum - ? propertyType.ToSerial(propertyAccess).ConvertToString() - : GetQueryParameterStringExpression(propertyType, propertyAccess, property.SerializationFormat); - - MethodBodyStatement appendStatement = - uri.AppendQuery(Literal(property.WireInfo!.SerializedName), convertedValue, true).Terminate(); - - if (!property.WireInfo.IsRequired || - propertyType.IsNullable || - (propertyType is { IsValueType: false, IsFrameworkType: true } && propertyType.FrameworkType != typeof(string))) - { - appendStatement = BuildQueryOrHeaderOrPathParameterNullCheck(propertyType, propertyAccess, appendStatement); - } - - statements.Add(appendStatement); - } - - return statements; - } - private static IfStatement BuildQueryOrHeaderOrPathParameterNullCheck( CSharpType? parameterType, ValueExpression valueExpression, @@ -919,7 +881,9 @@ private static void AppendLiteralSegment(ScopedApi uri, string literal, List paramMap, InputOperation operation, InputParameter inputParam, out CSharpType? type, out SerializationFormat? serializationFormat, out ValueExpression? valueExpression) { - type = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); + type = IsGeneratedContentTypeMethodParameter(inputParam) + ? null + : ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(inputParam.Type); serializationFormat = null; if (inputParam.IsApiVersion && ClientProvider.IsMultiServiceClient) @@ -1208,7 +1172,10 @@ internal static List GetMethodParameters( // when one was already published. UpdateParameterNameWithBackCompat(inputParam, inputParam.Name, client.BackCompatProvider, serviceMethod); - ParameterProvider? parameter = ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); + ParameterProvider? parameter = IsContentTypeParameter(inputParam) && + methodType is ScmMethodKind.Protocol or ScmMethodKind.CreateRequest + ? CreateContentTypeParameter(inputParam) + : ScmCodeModelGenerator.Instance.TypeFactory.CreateParameter(inputParam)?.ToPublicInputParameter(); if (parameter is null) { continue; @@ -1249,7 +1216,7 @@ internal static List GetMethodParameters( break; case ParameterLocation.Query: case ParameterLocation.Header: - if (inputParam is InputHeaderParameter { IsContentType: true } + if (IsContentTypeParameter(inputParam) && !HasContentTypeBeforeBodyInLastContract(serviceMethod.Name, client.BackCompatProvider)) { sortedParams.Add(contentType++, parameter); @@ -1292,12 +1259,25 @@ internal static List GetMethodParameters( return [.. sortedParams.Values]; } + private static ParameterProvider CreateContentTypeParameter(InputParameter inputParam) + { + var type = new CSharpType(typeof(string), isNullable: !inputParam.IsRequired); + return new ParameterProvider( + inputParam.Name, + DocHelpers.GetFormattableDescription(inputParam.Summary, inputParam.Doc) ?? FormattableStringHelpers.Empty, + type, + defaultValue: inputParam.IsRequired ? null : Default, + location: ParameterLocation.Header, + wireInfo: new WireInformation(SerializationFormat.Default, inputParam.SerializedName), + validation: inputParam.IsRequired ? ParameterValidationType.AssertNotNullOrEmpty : ParameterValidationType.None, + inputParameter: inputParam); + } + private static bool HasLiteralContentTypeHeader(InputOperation operation) { foreach (var p in operation.Parameters) { - if (p is InputHeaderParameter { IsContentType: true } header - && header.Type is InputLiteralType) + if (p is InputHeaderParameter { IsContentType: true } && p.Type is InputLiteralType) { return true; } @@ -1305,6 +1285,15 @@ private static bool HasLiteralContentTypeHeader(InputOperation operation) return false; } + private static bool IsContentTypeParameter(InputParameter parameter) => + parameter is InputHeaderParameter { IsContentType: true } || + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + + private static bool IsGeneratedContentTypeMethodParameter(InputParameter parameter) => + parameter is InputMethodParameter { Location: InputRequestLocation.Header } && + string.Equals(parameter.SerializedName, "Content-Type", StringComparison.OrdinalIgnoreCase); + /// /// Checks if the last contract view contains a method matching the given name where /// a "contentType" parameter appears before the body ("content") parameter. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs index af294640060..e902afd7c79 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/SerializationFormatDefinition.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -45,6 +46,7 @@ protected override TypeSignatureModifiers BuildDeclarationModifiers() protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", "Internal", $"{Name}.cs"); protected override string BuildName() => "SerializationFormat"; + protected override FormattableString BuildDescription() => $"The serialization format."; protected override TypeProvider[] BuildSerializationProviders() => []; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs index 588a2094b12..fdd8f72fba0 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Snippets/HttpRequestApiSnippets.cs @@ -26,7 +26,9 @@ public static MethodBodyStatement SetContent(this ScopedApi pip public static MethodBodyStatement SetHeaderDelimited(this HttpRequestApi pipelineRequest, string name, ValueExpression value, ValueExpression delimiter, ValueExpression? format = null) { ValueExpression[] parameters = format != null ? [Literal(name), value, delimiter, format] : [Literal(name), value, delimiter]; - return pipelineRequest.Property(nameof(PipelineRequest.Headers)).Invoke("SetDelimited", parameters).Terminate(); + return pipelineRequest.Property(nameof(PipelineRequest.Headers)) + .Invoke("SetDelimited", parameters, typeArguments: null, callAsAsync: false, extensionType: new PipelineRequestHeadersExtensionsDefinition().Type) + .Terminate(); } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs new file mode 100644 index 00000000000..2cadcbe5cb0 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/PostProcessing/ClientBodyDependencyPostProcessingTests.cs @@ -0,0 +1,881 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.TypeSpec.Generator.Input; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Tests.Common; +using NUnit.Framework; + +namespace Microsoft.TypeSpec.Generator.ClientModel.Tests.PostProcessing +{ + public class ClientBodyDependencyPostProcessingTests + { + [Test] + public async Task OperationBodyParameterModelDoesNotBecomePublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([requestModel], [client], ["RequestBody"]); + } + + [Test] + public async Task OperationResponseBodyModelRemainsPublicAsRootOutputModel() + { + var responseModel = InputFactory.Model("ResponseBody"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel], [client], ["ResponseBody"]); + } + + [Test] + public async Task InternalModelReferencedByPublicModelPropertyIsPublicized() + { + var dependencyModel = InputFactory.Model("DependencyModel", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependency", dependencyModel)]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel, dependencyModel], [client], ["ResponseBody", "DependencyModel"]); + } + + [Test] + public async Task InternalModelReferencedByPublicNonRootCollectionPropertyIsPublicized() + { + var dependencyModel = InputFactory.Model("DependencyModel", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependencies", InputFactory.Array(dependencyModel))]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [responseModel, dependencyModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Generated", "SampleModelFactory.cs"), """ + using System.Collections.Generic; + using Sample.Models; + + namespace Sample; + + public static partial class SampleModelFactory + { + public static ResponseBody ResponseBody(IEnumerable dependencies = default) => null; + } + """) + ], + expectedFiles: [], + publicModelNames: ["ResponseBody", "DependencyModel"], + configureGenerator: () => + { + var responseProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "ResponseBody"); + var dependencyProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "DependencyModel"); + CodeModelGenerator.Instance.AddTypeToKeep(responseProvider, isRoot: false); + CodeModelGenerator.Instance.AddTypeToKeep(dependencyProvider, isRoot: false); + }); + } + + [Test] + public async Task CustomInternalBoundaryInternalizesPublicNonRootModel() + { + var customInternalModel = InputFactory.Model("CustomInternalModel"); + var publicWrapper = InputFactory.Model( + "PublicWrapper", + properties: [InputFactory.Property("CustomInternal", customInternalModel)]); + + var outputPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Directory.CreateDirectory(outputPath); + try + { + var customPath = Path.Combine(outputPath, "src", "Custom", "CustomInternalModel.cs"); + Directory.CreateDirectory(Path.GetDirectoryName(customPath)!); + File.WriteAllText(customPath, """ + using Microsoft.TypeSpec.Generator.Customizations; + + namespace Sample.Models; + + [CodeGenType("CustomInternalModel")] + internal partial class CustomInternalModel + { + } + """); + var modelFactoryPath = Path.Combine(outputPath, "src", "Generated", "SampleModelFactory.cs"); + Directory.CreateDirectory(Path.GetDirectoryName(modelFactoryPath)!); + File.WriteAllText(modelFactoryPath, """ + using Sample.Models; + + namespace Sample; + + public static partial class SampleModelFactory + { + public static PublicWrapper PublicWrapper(CustomInternalModel customInternal = default) => null; + } + """); + + await MockHelpers.LoadMockGeneratorAsync( + inputModels: () => [publicWrapper, customInternalModel], + configuration: """{ "package-name": "Sample", "disable-xml-docs": true }""", + outputPath: outputPath); + + var publicWrapperProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "PublicWrapper"); + var customInternalProvider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "CustomInternalModel"); + CodeModelGenerator.Instance.AddTypeToKeep(publicWrapperProvider, isRoot: false); + CodeModelGenerator.Instance.AddTypeToKeep(customInternalProvider, isRoot: false); + + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility(CodeModelGenerator.Instance.OutputLibrary.TypeProviders); + + Assert.IsTrue(publicWrapperProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal), "PublicWrapper should be internalized when its public surface exposes a custom/internal type."); + Assert.IsTrue(customInternalProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal), "CustomInternalModel should remain internal."); + } + finally + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + if (Directory.Exists(outputPath)) + { + Directory.Delete(outputPath, recursive: true); + } + } + } + + [Test] + public async Task OperationResponseBodyModelIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyModel = InputFactory.Model("MetadataOnlyResponse"); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyModel, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [metadataOnlyModel], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.Serialization.cs") + ]); + } + + [Test] + public async Task EmptyCustomPartialModelIsKept() + { + var customizedModel = InputFactory.Model("CustomizedModel"); + + await GenerateAndAssertFiles( + enums: [], + models: [customizedModel], + clients: [], + customFiles: [( + Path.Combine("src", "CustomizedModel.cs"), + """ + namespace Sample.Models + { + public partial class CustomizedModel + { + } + } + """)], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "CustomizedModel.cs"), + Path.Combine("src", "Generated", "Models", "CustomizedModel.Serialization.cs") + ]); + } + + [Test] + public async Task InternalAdditionalRootModelIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyModel = InputFactory.Model("MetadataOnlyResponse", access: "internal"); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyModel, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [metadataOnlyModel], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponse.Serialization.cs") + ], + configureGenerator: () => + { + var provider = CodeModelGenerator.Instance.OutputLibrary.TypeProviders.Single(provider => provider.Name == "MetadataOnlyResponse"); + CodeModelGenerator.Instance.AddTypeToKeep(provider); + }); + } + + [Test] + public async Task AdditionalRootEnumIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyEnum = InputFactory.StringEnum( + "MetadataOnlyResponseKind", + [("Accepted", "accepted")]); + var operation = InputFactory.Operation( + "Get", + responses: [ + InputFactory.OperationResponse(bodytype: InputPrimitiveType.String), + new InputOperationResponse([202], metadataOnlyEnum, [], isErrorResponse: false, ["application/json"]) + ]); + var method = InputFactory.BasicServiceMethod( + "Get", + operation, + response: InputFactory.ServiceMethodResponse(InputPrimitiveType.String, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [metadataOnlyEnum], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponseKind.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyResponseKind.Serialization.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumIsRemovedWhenNotOtherwiseReferenced() + { + var contentTypeEnum = InputFactory.StringEnum( + "UpdateSnapshotRequestContentType", + [ + ("ApplicationMergePatchJson", "application/merge-patch+json"), + ("ApplicationJson", "application/json") + ]); + var contentTypeParameter = InputFactory.MethodParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + location: InputRequestLocation.Header, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "UpdateSnapshot", + parameters: [contentTypeParameter], + httpMethod: "PATCH", + generateConvenienceMethod: false); + var method = InputFactory.BasicServiceMethod( + "UpdateSnapshot", + operation, + parameters: [ + InputFactory.MethodParameter("name", InputPrimitiveType.String, isRequired: true, location: InputRequestLocation.Path), + contentTypeParameter, + InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true) + ]); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "UpdateSnapshotRequestContentType.cs"), + Path.Combine("src", "Generated", "Models", "UpdateSnapshotRequestContentType.Serialization.cs") + ], + configureGenerator: () => + CodeModelGenerator.Instance.TypeFactory.CreateCSharpType(InputFactory.Union([contentTypeEnum], "contentType"))); + } + + [Test] + public async Task PublicEnumIsRemovedWhenNotOtherwiseReferenced() + { + var metadataOnlyEnum = InputFactory.StringEnum( + "MetadataOnlyKind", + [("One", "one")]); + + await GenerateAndAssertFiles( + enums: [metadataOnlyEnum], + models: [], + clients: [], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Models", "MetadataOnlyKind.cs"), + Path.Combine("src", "Generated", "Models", "MetadataOnlyKind.Serialization.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumReferencedByCustomSuppressionIsKept() + { + var contentTypeEnum = InputFactory.StringEnum( + "PutKeyValueRequestContentType", + [("ApplicationJson", "application/json")], + isExtensible: true); + var contentTypeParameter = InputFactory.HeaderParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + isContentType: true, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "SetConfigurationSettingInternal", + parameters: [contentTypeParameter], + httpMethod: "PUT"); + var method = InputFactory.BasicServiceMethod("SetConfigurationSettingInternal", operation); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [ + (Path.Combine("src", "PutKeyValueRequestContentType.cs"), """ + namespace Sample.Models; + + internal readonly partial struct PutKeyValueRequestContentType + { + public static PutKeyValueRequestContentType ApplicationJson { get; } = new PutKeyValueRequestContentType("application/json"); + } + """) + ], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "PutKeyValueRequestContentType.cs") + ]); + } + + [Test] + public async Task ContentTypeHeaderEnumReferencedOnlyByCustomSuppressionAttributeIsKept() + { + var contentTypeEnum = InputFactory.StringEnum( + "CreateSnapshotRequestContentType", + [("ApplicationJson", "application/json")], + isExtensible: true); + var contentTypeParameter = InputFactory.MethodParameter( + "contentType", + InputFactory.Union([contentTypeEnum], "contentType"), + isRequired: true, + location: InputRequestLocation.Header, + serializedName: "Content-Type"); + var operation = InputFactory.Operation( + "CreateSnapshot", + parameters: [contentTypeParameter], + httpMethod: "PUT", + generateConvenienceMethod: false); + var method = InputFactory.BasicServiceMethod( + "CreateSnapshot", + operation, + parameters: [ + InputFactory.MethodParameter("name", InputPrimitiveType.String, isRequired: true, location: InputRequestLocation.Path), + contentTypeParameter, + InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true) + ]); + var client = InputFactory.Client("ConfigurationClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [contentTypeEnum], + models: [], + clients: [client], + customFiles: [ + (Path.Combine("src", "ConfigurationClient.cs"), """ + namespace Sample; + + [Microsoft.TypeSpec.Generator.Customizations.CodeGenType("ConfigurationClient")] + [Microsoft.TypeSpec.Generator.Customizations.CodeGenSuppress("CreateSnapshot", typeof(string), typeof(CreateSnapshotRequestContentType))] + public partial class ConfigurationClient + { + } + """) + ], + expectedFiles: [ + Path.Combine("src", "Generated", "Models", "CreateSnapshotRequestContentType.cs") + ]); + } + + [Test] + public async Task NestedBodyModelGraphDoesNotBecomePublic() + { + var nestedModel = InputFactory.Model("NestedToolParameter"); + var toolModel = InputFactory.Model( + "ToolConfig", + properties: [InputFactory.Property("Parameter", nestedModel)]); + var parameter = InputFactory.BodyParameter("tool", toolModel, isRequired: true); + var operation = InputFactory.Operation("Configure", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Configure", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertInternalModels([toolModel, nestedModel], [client], ["ToolConfig", "NestedToolParameter"]); + } + + [Test] + public async Task NonDiscriminatorDerivedBodyModelDoesNotBecomePublicFromPublicBase() + { + var baseTool = InputFactory.Model("BaseTool"); + var concreteTool = InputFactory.Model( + "ConcreteTool", + properties: [InputFactory.Property("Name", InputPrimitiveType.String)], + baseModel: baseTool); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseTool)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseTool, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertMixedModels( + [baseTool, concreteTool], + [client], + publicModelNames: ["BaseTool"], + internalModelNames: ["ConcreteTool"]); + } + + [Test] + public async Task PublicModelSignatureDependencyIsPromotedToPublic() + { + var internalDependency = InputFactory.Model("InternalDependency", access: "internal"); + var responseModel = InputFactory.Model( + "ResponseBody", + properties: [InputFactory.Property("Dependency", internalDependency)]); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertPublicModels([responseModel, internalDependency], [client], ["ResponseBody", "InternalDependency"]); + } + + [Test] + public async Task AzureClientPublicMethodSignatureReferencesStayPublic() + { + var signatureModel = InputFactory.Model("SignatureModel", @namespace: "Azure.Sample.Models"); + var methodParameter = InputFactory.MethodParameter("signature", signatureModel, isRequired: true); + var operation = InputFactory.Operation( + "Create", + parameters: [InputFactory.BodyParameter("signature", signatureModel, isRequired: true)], + httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation, parameters: [methodParameter]); + var client = InputFactory.Client("SampleClient", clientNamespace: "Azure.Sample", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [signatureModel], + clients: [client], + customFiles: [], + expectedFiles: [], + publicModelNames: ["SignatureModel"], + packageName: "Azure.Sample"); + } + + [Test] + public async Task BasePreservedDerivedModelTraversesTransitiveDependencies() + { + var transitiveDependency = InputFactory.Model("TransitiveDependency"); + var dependency = InputFactory.Model( + "DerivedDependency", + properties: [InputFactory.Property("Transitive", transitiveDependency)]); + var baseModel = InputFactory.Model("BaseResult"); + var derivedModel = InputFactory.Model( + "DerivedResult", + properties: [InputFactory.Property("Dependency", dependency)], + baseModel: baseModel); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: baseModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(baseModel, [])); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [baseModel, derivedModel, dependency, transitiveDependency], + clients: [client], + customFiles: [], + expectedFiles: [], + publicModelNames: ["BaseResult"], + internalModelNames: ["DerivedResult", "DerivedDependency", "TransitiveDependency"]); + } + + [Test] + public async Task PublicCustomCodeArraySignatureReferencesStayPublic() + { + var generatedModel = InputFactory.Model("GeneratedModel"); + + await GenerateAndAssertFiles( + enums: [], + models: [generatedModel], + clients: [], + customFiles: [ + (Path.Combine("src", "PublicCustomApi.cs"), """ + using Sample.Models; + + namespace Sample; + + public partial class PublicCustomApi + { + public GeneratedModel[] Items { get; } = System.Array.Empty(); + } + """) + ], + expectedFiles: [], + publicModelNames: ["GeneratedModel"]); + } + + [Test] + public async Task GeneratedRequestHeaderSetDelimitedReferenceKeepsExtensions() + { + var header = InputFactory.HeaderParameter("x-ms-custom", InputFactory.Array(InputPrimitiveType.String), isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [header]); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [Path.Combine("src", "Generated", "Internal", "PipelineRequestHeadersExtensions.cs")]); + } + + [Test] + public async Task BinaryDataBodyParameterDoesNotKeepBinaryContentHelpers() + { + var parameter = InputFactory.BodyParameter( + "content", + InputPrimitiveType.Base64, + isRequired: true, + contentTypes: ["application/octet-stream"], + defaultContentType: "application/octet-stream"); + var operation = InputFactory.Operation("Upload", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod( + "Upload", + operation, + parameters: [InputFactory.MethodParameter("content", InputPrimitiveType.Base64, isRequired: true)]); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [], + unexpectedFiles: [ + Path.Combine("src", "Generated", "Internal", "BinaryContentHelper.cs"), + Path.Combine("src", "Generated", "Internal", "Utf8JsonBinaryContent.cs") + ]); + } + + [Test] + public async Task CollectionBodyParameterKeepsBinaryContentHelpers() + { + var parameter = InputFactory.BodyParameter("items", InputFactory.Array(InputPrimitiveType.String), isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod( + "Create", + operation, + parameters: [InputFactory.MethodParameter("items", InputFactory.Array(InputPrimitiveType.String), isRequired: true)]); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [client], + customFiles: [], + expectedFiles: [ + Path.Combine("src", "Generated", "Internal", "BinaryContentHelper.cs"), + Path.Combine("src", "Generated", "Internal", "Utf8JsonBinaryContent.cs") + ]); + } + + [Test] + public async Task CustomOnlyRequestHeaderSetDelimitedReferenceKeepsExtensions() + { + await GenerateAndAssertFiles( + enums: [], + models: [], + clients: [], + customFiles: [ + (Path.Combine("src", "CustomHeaders.cs"), """ + using System.ClientModel.Primitives; + + namespace Sample; + + public static class CustomHeaders + { + public static void Add(PipelineRequestHeaders headers, string[] values) + => headers.SetDelimited("x-ms-custom", values, ","); + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Internal", "PipelineRequestHeadersExtensions.cs")]); + } + + [Test] + public async Task CustomizedEnumSerializationProviderIsKeptWhenModelSerializationUsesEnum() + { + var statusEnum = InputFactory.StringEnum( + "Status", + [("Succeeded", "succeeded"), ("Failed", "failed")], + clientNamespace: "Sample"); + var resultModel = InputFactory.Model( + "OperationResult", + properties: [InputFactory.Property("Status", statusEnum, isRequired: true)], + @namespace: "Sample"); + var operation = InputFactory.Operation("Get", responses: [InputFactory.OperationResponse(bodytype: resultModel)]); + var method = InputFactory.BasicServiceMethod("Get", operation, response: InputFactory.ServiceMethodResponse(resultModel, [])); + var client = InputFactory.Client("TestClient", methods: [method], clientNamespace: "Sample"); + + await GenerateAndAssertFiles( + enums: [statusEnum], + models: [resultModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Custom", "Status.cs"), """ + namespace Sample; + + [CodeGenType("Status")] + public enum Status + { + Succeeded, + Failed + } + """) + ], + expectedFiles: [Path.Combine("src", "Generated", "Models", "Status.Serialization.cs")]); + } + + [Test] + public async Task CustomModelFactoryPartialDoesNotKeepBodyOnlyModelPublic() + { + var requestModel = InputFactory.Model("RequestBody"); + var parameter = InputFactory.BodyParameter("body", requestModel, isRequired: true); + var operation = InputFactory.Operation("Create", parameters: [parameter], httpMethod: "POST"); + var method = InputFactory.BasicServiceMethod("Create", operation); + var client = InputFactory.Client("TestClient", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [requestModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "SampleModelFactory.cs"), """ + namespace Sample; + + [Microsoft.TypeSpec.Generator.Customizations.CodeGenType("SampleModelFactory")] + public static partial class SampleModelFactory + { + } + """) + ], + expectedFiles: [], + internalModelNames: ["RequestBody"]); + } + + [Test] + public async Task InternalCustomClientPartialOverridesLastContractPublicClient() + { + var responseModel = InputFactory.Model("CompactResource"); + var operation = InputFactory.Operation("Compact", responses: [InputFactory.OperationResponse(bodytype: responseModel)]); + var method = InputFactory.BasicServiceMethod("Compact", operation, response: InputFactory.ServiceMethodResponse(responseModel, [])); + var client = InputFactory.Client("Responses", methods: [method]); + + await GenerateAndAssertFiles( + enums: [], + models: [responseModel], + clients: [client], + customFiles: [ + (Path.Combine("src", "Generated", "Responses.cs"), """ + namespace Sample; + + public partial class Responses + { + } + """), + (Path.Combine("src", "Custom", "Internal", "Responses.cs"), """ + namespace Sample; + + internal partial class Responses + { + } + """) + ], + expectedFiles: [], + internalModelNames: ["CompactResource"], + internalClientNames: ["Responses"]); + } + + private static async Task GenerateAndAssertInternalModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: false); + + private static async Task GenerateAndAssertPublicModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames) + => await GenerateAndAssertModels(models, clients, modelNames, shouldBePublic: true); + + private static async Task GenerateAndAssertMixedModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + => await GenerateAndAssertModels(models, clients, publicModelNames, internalModelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] modelNames, + bool shouldBePublic) + => await GenerateAndAssertModels( + models, + clients, + shouldBePublic ? modelNames : [], + shouldBePublic ? [] : modelNames); + + private static async Task GenerateAndAssertModels( + InputModelType[] models, + InputClient[] clients, + string[] publicModelNames, + string[] internalModelNames) + { + await GenerateAndAssertFiles( + enums: [], + models: models, + clients: clients, + customFiles: [], + publicModelNames: publicModelNames, + internalModelNames: internalModelNames, + expectedFiles: []); + } + + private static async Task GenerateAndAssertFiles( + InputEnumType[] enums, + InputModelType[] models, + InputClient[] clients, + (string Path, string Content)[] customFiles, + string[] expectedFiles, + string[] unexpectedFiles = null!, + string[] publicModelNames = null!, + string[] internalModelNames = null!, + string[] internalClientNames = null!, + string packageName = "Sample", + Action? configureGenerator = null) + { + publicModelNames ??= []; + internalModelNames ??= []; + internalClientNames ??= []; + unexpectedFiles ??= []; + + var outputPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Directory.CreateDirectory(outputPath); + try + { + foreach (var customFile in customFiles) + { + var customPath = Path.Combine(outputPath, customFile.Path); + Directory.CreateDirectory(Path.GetDirectoryName(customPath)!); + File.WriteAllText(customPath, customFile.Content); + } + + await MockHelpers.LoadMockGeneratorAsync( + inputEnums: () => enums, + inputModels: () => models, + clients: () => clients, + configuration: $$"""{ "package-name": "{{packageName}}", "disable-xml-docs": true }""", + outputPath: outputPath); + configureGenerator?.Invoke(); + + await new CSharpGen().ExecuteAsync(); + + foreach (var modelName in publicModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"public partial class {modelName}", text, $"{modelName} should be public."); + } + + foreach (var modelName in internalModelNames) + { + var modelPath = Path.Combine(outputPath, "src", "Generated", "Models", $"{modelName}.cs"); + Assert.IsTrue(File.Exists(modelPath), $"Expected generated model file '{modelPath}'."); + var text = File.ReadAllText(modelPath); + StringAssert.Contains($"internal partial class {modelName}", text, $"{modelName} should be internal."); + StringAssert.DoesNotContain($"public partial class {modelName}", text, $"{modelName} should not be public."); + } + + foreach (var clientName in internalClientNames) + { + var clientPath = Path.Combine(outputPath, "src", "Generated", $"{clientName}.cs"); + Assert.IsTrue(File.Exists(clientPath), $"Expected generated client file '{clientPath}'."); + var text = File.ReadAllText(clientPath); + StringAssert.Contains($"internal partial class {clientName}", text, $"{clientName} should be internal."); + StringAssert.DoesNotContain($"public partial class {clientName}", text, $"{clientName} should not be public."); + } + + var modelFactoryPath = Path.Combine(outputPath, "src", "Generated", "SampleModelFactory.cs"); + if (File.Exists(modelFactoryPath)) + { + var modelFactoryText = File.ReadAllText(modelFactoryPath); + foreach (var modelName in publicModelNames) + { + StringAssert.Contains($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should be generated."); + } + + foreach (var modelName in internalModelNames) + { + StringAssert.DoesNotContain($" {modelName}(", modelFactoryText, $"Model factory method for {modelName} should not be generated."); + } + } + + foreach (var expectedFile in expectedFiles) + { + var filePath = Path.Combine(outputPath, expectedFile); + Assert.IsTrue(File.Exists(filePath), $"Expected generated file '{filePath}'."); + } + + foreach (var unexpectedFile in unexpectedFiles) + { + var filePath = Path.Combine(outputPath, unexpectedFile); + Assert.IsFalse(File.Exists(filePath), $"Did not expect generated file '{filePath}'."); + } + } + finally + { + if (Directory.Exists(outputPath)) + { + Directory.Delete(outputPath, recursive: true); + } + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs index fcc90582416..503ed1ad68f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs @@ -121,31 +121,6 @@ public void JsonModelWriteCore_IsOverride_WhenBaseIsRegularModel() "JsonModelWriteCore should be 'override' with regular base too"); } - [Test] - public void JsonModelWriteCore_IsOverride_WhenBaseProviderIsResolvedAfterSerialization() - { - var baseInputModel = InputFactory.Model("Resource"); - var derivedInputModel = InputFactory.Model("TrackedResource", properties: [InputFactory.Property("Location", InputPrimitiveType.String)]); - MockHelpers.LoadMockGenerator(inputModels: () => [baseInputModel, derivedInputModel]); - - var derived = new DelayedBaseModelProvider(derivedInputModel); - var serialization = new MrwSerializationTypeDefinition(derivedInputModel, derived); - - // The serialization provider can be constructed before later visitors/customization - // resolution make the base model provider available. - derived.BaseModel = new SystemObjectModelProvider(new CSharpType(typeof(object)), baseInputModel); - - var method = serialization.BuildJsonModelWriteCoreMethod(); - - Assert.AreEqual(derived.BaseModel.Type, derived.Type.BaseType, - "The generated model type should inherit the base resolved after serialization construction."); - Assert.AreEqual(derived.BaseModel.Type, serialization.Type.BaseType, - "The serialization type should inherit the same resolved base."); - Assert.IsTrue(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Override), - "JsonModelWriteCore should evaluate BaseModelProvider when the method is built, not when serialization is constructed"); - Assert.IsFalse(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Virtual)); - } - // ------------------------------------------------------------------- // PersistableModelWriteCore: 'virtual' with system base, 'override' with regular // (the framework base already implements this; derived model re-introduces it) @@ -353,14 +328,5 @@ FakeMrwBase IPersistableModel.Create(BinaryData data, ModelReaderWr string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; } - - private class DelayedBaseModelProvider(InputModelType inputModel) : ModelProvider(inputModel) - { - public ModelProvider? BaseModel { get; set; } - - protected override ModelProvider? BuildBaseModelProvider() => BaseModel; - - protected override CSharpType? BuildBaseType() => BaseModel?.Type; - } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs index e8bcb2fa0fb..7e637e07363 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs @@ -738,38 +738,6 @@ public void ValidateGetResponseClassifiersThrowsWhenNoSuccess() Assert.Fail("Expected Exception to be thrown."); } - [Test] - public void TestBuildCreateRequestMethodWithExplodedModelQueryParameter() - { - var filterModel = InputFactory.Model( - "filterOptions", - properties: - [ - InputFactory.Property("field", InputPrimitiveType.String, isRequired: true), - InputFactory.Property("value", InputPrimitiveType.String, isRequired: true), - ]); - var operation = InputFactory.Operation( - "sampleOp", - parameters: [InputFactory.QueryParameter("filter", filterModel, isRequired: true, explode: true)]); - var client = InputFactory.Client( - "TestClient", - methods: [InputFactory.BasicServiceMethod("Test", operation)]); - var clientProvider = new ClientProvider(client); - var restClientProvider = new MockClientProvider(client, clientProvider); - - var method = restClientProvider.Methods.FirstOrDefault(m => m.Signature.Name == "CreateSampleOpRequest"); - Assert.IsNotNull(method); - var body = method!.BodyStatements!.ToDisplayString(); - - // A model-typed query parameter with `explode` is expanded into one query entry per - // property (RFC 6570 form explode) using each property's wire name, instead of serializing - // the whole object via ConvertToString (which produced the type name). - Assert.IsTrue(body.Contains("uri.AppendQuery(\"field\", filter.Field, true);"), body); - Assert.IsTrue(body.Contains("uri.AppendQuery(\"value\", filter.Value, true);"), body); - Assert.IsFalse(body.Contains("AppendQuery(\"filter\""), body); - Assert.IsFalse(body.Contains("ConvertToString(filter)"), body); - } - [Test] public void TestBuildCreateRequestMethodWithQueryParameters() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs index 9148f659e43..6c6e743ad89 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/TestHelpers/MockHelpers.cs @@ -34,7 +34,8 @@ public static async Task> LoadMockGeneratorAsync( Func>? apiVersions = null, string? configuration = null, Func? createCSharpTypeCore = null, - Func? createCSharpTypeCoreFallback = null) + Func? createCSharpTypeCoreFallback = null, + string? outputPath = null) { var mockGenerator = LoadMockGenerator( inputLiterals: inputLiterals, @@ -44,13 +45,13 @@ public static async Task> LoadMockGeneratorAsync( apiVersions: apiVersions, configuration: configuration, createCSharpTypeCore: createCSharpTypeCore, - createCSharpTypeCoreFallback: createCSharpTypeCoreFallback); + createCSharpTypeCoreFallback: createCSharpTypeCoreFallback, + outputPath: outputPath); var compilationResult = compilation == null ? null : await compilation(); var lastContractCompilationResult = lastContractCompilation == null ? null : await lastContractCompilation(); - var sourceInputModel = new Mock(() => new SourceInputModel(compilationResult, lastContractCompilationResult)) { CallBase = true }; - mockGenerator.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); + mockGenerator.SetupProperty(p => p.SourceInputModel, new SourceInputModel(compilationResult, lastContractCompilationResult)); return mockGenerator; } @@ -76,7 +77,8 @@ public static Mock LoadMockGenerator( Func? createOutputLibrary = null, bool includeXmlDocs = false, Func? createCSharpTypeCoreFallback = null, - Func? createModelCore = null) + Func? createModelCore = null, + string? outputPath = null) { IReadOnlyList inputNsApiVersions = apiVersions?.Invoke() ?? []; IReadOnlyList inputNsLiterals = inputLiterals?.Invoke() ?? []; @@ -150,7 +152,7 @@ public static Mock LoadMockGenerator( { configuration = "{\"disable-xml-docs\": false, \"package-name\": \"Sample.Namespace\"}"; } - object?[] parameters = [_configFilePath, configuration]; + object?[] parameters = [outputPath ?? _configFilePath, configuration]; var config = loadMethod?.Invoke(null, parameters); var mockGeneratorContext = new Mock(config!); var mockGeneratorInstance = new Mock(mockGeneratorContext.Object) { CallBase = true }; @@ -186,8 +188,7 @@ public static Mock LoadMockGenerator( mockGeneratorInstance.Setup(p => p.OutputLibrary).Returns(createOutputLibrary); } - var sourceInputModel = new Mock(() => new SourceInputModel(null, null)) { CallBase = true }; - mockGeneratorInstance.Setup(p => p.SourceInputModel).Returns(sourceInputModel.Object); + mockGeneratorInstance.SetupProperty(p => p.SourceInputModel, new SourceInputModel(null, null)); codeModelInstance!.SetValue(null, mockGeneratorInstance.Object); clientModelInstance!.SetValue(null, mockGeneratorInstance.Object); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs index 52d9c8c0bc8..f96ecad090e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/src/InputTypes/InputModelTypeUsage.cs @@ -22,6 +22,5 @@ public enum InputModelTypeUsage LroInitial = 2048, LroPolling = 4096, LroFinalEnvelope = 8192, - External = 16384, } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs index 7e5ec1c3c84..9328dc43571 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.Input/test/TypeSpecInputConverterTests.cs @@ -564,42 +564,6 @@ public void DeserializeModelWithExternalMetadata() Assert.AreEqual("8.0.0", model.External.MinVersion); } - [Test] - public void DeserializeModelWithExternalUsagePreservesInputAndOutput() - { - // TCGC emits the External usage flag (UsageFlags.External) for models that are also - // referenced by external types. The C# InputModelTypeUsage enum must recognize it so - // that Enum.TryParse does not fail on the unknown token and collapse the whole usage to - // None, which would strip Input/Output and make every property get-only. - var json = @"{ - ""$id"": ""1"", - ""kind"": ""model"", - ""name"": ""TestModel"", - ""namespace"": ""Test.Models"", - ""crossLanguageDefinitionId"": ""Test.Models.TestModel"", - ""usage"": ""Input,Output,External"", - ""properties"": [] - }"; - - var referenceHandler = new TypeSpecReferenceHandler(); - var options = new JsonSerializerOptions - { - AllowTrailingCommas = true, - Converters = - { - new InputTypeConverter(referenceHandler), - new InputModelTypeConverter(referenceHandler), - new InputExternalTypeMetadataConverter() - } - }; - - var model = JsonSerializer.Deserialize(json, options); - Assert.IsNotNull(model); - Assert.IsTrue(model!.Usage.HasFlag(InputModelTypeUsage.Input), "Model should retain Input usage flag"); - Assert.IsTrue(model.Usage.HasFlag(InputModelTypeUsage.Output), "Model should retain Output usage flag"); - Assert.IsTrue(model.Usage.HasFlag(InputModelTypeUsage.External), "Model should have External usage flag"); - } - [Test] public void DeserializeArrayWithExternalMetadata() { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs index c013817a72e..35914c1f8d4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/CSharpGen.cs @@ -27,12 +27,13 @@ public async Task ExecuteAsync() { CodeModelGenerator.Instance.Emitter.Info("Starting code generation"); CodeModelGenerator.Instance.Stopwatch.Start(); + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); var outputPath = CodeModelGenerator.Instance.Configuration.OutputDirectory; var generatedSourceOutputPath = CodeModelGenerator.Instance.Configuration.ProjectGeneratedDirectory; - // Resolve PackageReference items from the .csproj so custom code referencing - // external NuGet types (e.g., Azure.Storage.Common) compiles correctly. + // Resolve PackageReference items from the .csproj so custom code referencing external + // NuGet types compiles correctly. await GeneratedCodeWorkspace.AddPackageReferencesFromProject(); // Pre-walk the input library and resolve any external types that point at NuGet packages. @@ -90,12 +91,33 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), { // Ensure back-compatibility processing is done after all visitors have run outputType.ProcessTypeForBackCompatibility(); + } + + generatedCodeWorkspace.ApplyPreWriteAccessibility(output.TypeProviders); + generatedCodeWorkspace.AnalyzeProviderReferenceMap(output.TypeProviders); + + foreach (var outputType in output.TypeProviders) + { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(outputType)) + { + continue; + } + + if (outputType is ModelFactoryProvider && outputType.Methods.Count == 0) + { + continue; + } var writer = CodeModelGenerator.Instance.GetWriter(outputType); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); foreach (var serialization in outputType.SerializationProviders) { + if (!ProviderReferenceMapAnalyzer.ShouldWriteProvider(serialization)) + { + continue; + } + writer = CodeModelGenerator.Instance.GetWriter(serialization); generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write())); } @@ -104,6 +126,8 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), // Add all the generated files to the workspace await Task.WhenAll(generateFilesTasks); + ProviderReferenceMapAnalyzer.RestorePreWriteModelFactoryMethods(); + LoggingHelpers.LogElapsedTime("All generated types have been written into memory"); // Delete any old generated files @@ -112,14 +136,22 @@ await GeneratedCodeWorkspace.LoadBaselineContract(), LoggingHelpers.LogElapsedTime("All old generated files have been deleted"); await generatedCodeWorkspace.PostProcessAsync(); + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); - // Write the generated files to the output directory + var generatedFiles = new List<(string Name, string Text)>(); await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync()) { if (string.IsNullOrEmpty(file.Text)) { continue; } + + generatedFiles.Add((file.Name, file.Text)); + } + + // Write the generated files to the output directory + foreach (var file in generatedFiles) + { var filename = Path.Combine(outputPath, file.Name); CodeModelGenerator.Instance.Emitter.Info($"Writing {Path.GetFullPath(filename)}"); Directory.CreateDirectory(Path.GetDirectoryName(filename)!); @@ -177,9 +209,10 @@ private static void DeleteDirectory(string path, string[] filesToKeep) return; } + var fileNamesToKeep = filesToKeep.ToHashSet(StringComparer.Ordinal); foreach (var file in directoryInfo.GetFiles("*", SearchOption.AllDirectories)) { - if (!filesToKeep.Contains(file.Name)) + if (!fileNamesToKeep.Contains(file.Name)) { file.Delete(); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs index e9e016c3554..6df7ccc9758 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/LibraryVisitor.cs @@ -168,7 +168,7 @@ protected internal virtual void VisitLibrary(OutputLibrary library) /// /// The original . /// Null if it should be removed otherwise the modified version of the . - protected internal virtual ConstructorProvider? VisitConstructor(ConstructorProvider constructor) + protected virtual ConstructorProvider? VisitConstructor(ConstructorProvider constructor) { return constructor; } @@ -302,7 +302,7 @@ protected internal virtual FinallyExpression VisitFinallyExpression(FinallyExpre /// /// The original . /// Null if it should be removed otherwise the modified version of the . - protected internal virtual FieldProvider? VisitField(FieldProvider field) + protected virtual FieldProvider? VisitField(FieldProvider field) { return field; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs index c36686f637f..74f9fc6b971 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/GeneratedCodeWorkspace.cs @@ -84,6 +84,16 @@ public async Task AddInMemoryFile(TypeProvider type) await UpdateProject(document); } + internal void AnalyzeProviderReferenceMap(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.Analyze(providers); + } + + internal void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility(providers); + } + private async Task UpdateProject(Document document) { var root = await document.GetSyntaxRootAsync(); @@ -278,10 +288,8 @@ public async Task PostProcessAsync() case Configuration.UnreferencedTypesHandlingOption.KeepAll: break; case Configuration.UnreferencedTypesHandlingOption.Internalize: - _project = await postProcessor.InternalizeAsync(_project); break; case Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize: - _project = await postProcessor.InternalizeAsync(_project); _project = await postProcessor.RemoveAsync(_project); break; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs index dc42f801732..bef5dce85b4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/PostProcessor.cs @@ -9,7 +9,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Simplification; namespace Microsoft.TypeSpec.Generator { @@ -113,58 +112,6 @@ private async Task GetTypeSymbolsAsync(Compilation compilation, protected virtual bool ShouldIncludeDocument(Document document) => !GeneratedCodeWorkspace.IsGeneratedTestDocument(document); - /// - /// This method marks the "not publicly" referenced types as internal if they are previously defined as public. It will do this job in the following steps: - /// 1. This method will read all the public types defined in the given , and build a cache for those symbols - /// 2. Build a public reference map for those symbols - /// 3. Finds all the root symbols, please override the to control which document you would like to include - /// 4. Visit all the symbols starting from the root symbols following the reference map to get all unvisited symbols - /// 5. Change the accessibility of the unvisited symbols in step 4 to internal - /// - /// The project to process - /// The processed . is immutable, therefore this should usually be a new instance - public async Task InternalizeAsync(Project project) - { - var compilation = await project.GetCompilationAsync(); - if (compilation == null) - { - return project; - } - - // first get all the declared symbols - var definitions = await GetTypeSymbolsAsync(compilation, project, true); - // build the reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DeclaredNodesCache); - // get the root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // traverse all the root and recursively add all the things we met - var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - var symbolsToInternalize = definitions.DeclaredSymbols.Except(publicSymbols); - - var nodesToInternalize = new Dictionary(); - foreach (var symbol in symbolsToInternalize) - { - foreach (var node in definitions.DeclaredNodesCache[symbol]) - { - nodesToInternalize[node] = project.GetDocumentId(node.SyntaxTree)!; - } - } - - foreach (var (model, documentId) in nodesToInternalize) - { - project = MarkInternal(project, model, documentId); - } - - var modelNamesToRemove = - nodesToInternalize.Keys.Select(item => item.Identifier.Text); - project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove.ToHashSet()); - - return project; - } - private async Task RemoveMethodsFromModelFactoryAsync(Project project, TypeSymbols definitions, HashSet namesToRemove) @@ -246,25 +193,32 @@ public async Task RemoveAsync(Project project) // find all the declarations, including non-public declared var definitions = await GetTypeSymbolsAsync(compilation, project, false); - // build reference map - var referenceMap = - await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( - definitions.DeclaredSymbols, definitions.DocumentsCache); - // get root symbols - var rootSymbols = await GetRootSymbolsAsync(project, definitions); - // include model factory as a root symbol when doing the remove pass so that we are sure to include any internal - // helpers that are required by the model factory. - if (_modelFactorySymbol != null) + IEnumerable symbolsToRemove; + HashSet referencedSet; + if (ProviderReferenceMapAnalyzer.LatestResult is { } referenceMapResult) { - rootSymbols.Add(_modelFactorySymbol); + // The remove pass uses the same precomputed hybrid map to avoid scanning all generated + // documents with Roslyn while preserving custom-code references as roots. + symbolsToRemove = GetSymbolsByName(definitions.DeclaredSymbols, referenceMapResult.RemoveCandidates).ToArray(); + referencedSet = new HashSet(definitions.DeclaredSymbols.Except(symbolsToRemove), SymbolEqualityComparer.Default); } - // traverse the map to determine the declarations that we are about to remove, starting from root nodes - var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); - - referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); - var referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + else + { + var referenceMap = await new ReferenceMapBuilder(compilation, project).BuildAllReferenceMapAsync( + definitions.DeclaredSymbols, definitions.DocumentsCache); + // Include model factory as a root symbol when doing the remove pass so that we are sure to include any internal + // helpers that are required by the model factory. + var rootSymbols = await GetRootSymbolsAsync(project, definitions); + if (_modelFactorySymbol != null) + { + rootSymbols.Add(_modelFactorySymbol); + } - var symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + var referencedSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap); + referencedSymbols = AddSampleSymbols(referencedSymbols, definitions.DeclaredSymbols); + referencedSet = new HashSet(referencedSymbols, SymbolEqualityComparer.Default); + symbolsToRemove = definitions.DeclaredSymbols.Except(referencedSet); + } var nodesToRemove = new List(); foreach (var symbol in symbolsToRemove) @@ -276,6 +230,14 @@ public async Task RemoveAsync(Project project) nodesToRemove.AddRange(definitions.DeclaredNodesCache[symbol]); } + var modelNamesToRemove = nodesToRemove + .Select(static item => item.Identifier.Text) + .ToHashSet(StringComparer.Ordinal); + if (modelNamesToRemove.Count > 0) + { + project = await RemoveMethodsFromModelFactoryAsync(project, definitions, modelNamesToRemove); + } + // remove them one by one project = await RemoveModelsAsync(project, nodesToRemove); @@ -352,18 +314,19 @@ private static IEnumerable GetReferencedTypes(T definition, return Enumerable.Empty(); } - private Project MarkInternal(Project project, BaseTypeDeclarationSyntax declarationNode, DocumentId documentId) + private static IEnumerable GetSymbolsByName(IEnumerable symbols, HashSet names) { - var newNode = ChangeModifier(declarationNode, SyntaxKind.PublicKeyword, SyntaxKind.InternalKeyword); - var tree = declarationNode.SyntaxTree; - var document = project.GetDocument(documentId)!; - var newRoot = tree.GetRoot().ReplaceNode(declarationNode, newNode) - .WithAdditionalAnnotations(Simplifier.Annotation); - document = document.WithSyntaxRoot(newRoot); - return document.Project; + foreach (var symbol in symbols) + { + if (names.Contains(symbol.GetFullyQualifiedName())) + { + yield return symbol; + } + } } - private async Task RemoveModelsAsync(Project project, + private async Task RemoveModelsAsync( + Project project, IEnumerable unusedModels) { // accumulate the definitions from the same document together @@ -392,24 +355,6 @@ private async Task RemoveModelsAsync(Project project, return project; } - private static BaseTypeDeclarationSyntax ChangeModifier(BaseTypeDeclarationSyntax memberDeclaration, - SyntaxKind from, - SyntaxKind to) - { - var originalTokenInList = memberDeclaration.Modifiers.FirstOrDefault(token => token.IsKind(from)); - - // skip this if there is nothing to replace - if (originalTokenInList == default) - { - return memberDeclaration; - } - - var newToken = - SyntaxFactory.Token(originalTokenInList.LeadingTrivia, to, originalTokenInList.TrailingTrivia); - var newModifiers = memberDeclaration.Modifiers.Replace(originalTokenInList, newToken); - return memberDeclaration.WithModifiers(newModifiers); - } - private async Task RemoveModelsFromDocumentAsync(Project project, IEnumerable models) { @@ -479,7 +424,14 @@ private async Task RemoveInvalidUsings(Solution solution, DocumentId d if (invalidUsings.Count > 0) { + var leadingTrivia = invalidUsings[0].GetLeadingTrivia(); cu = cu.RemoveNodes(invalidUsings, SyntaxRemoveOptions.KeepNoTrivia)!; + if (leadingTrivia.Count > 0) + { + var firstToken = cu.GetFirstToken(includeZeroWidth: true); + cu = cu.ReplaceToken(firstToken, firstToken.WithLeadingTrivia(leadingTrivia.AddRange(firstToken.LeadingTrivia))); + } + solution = solution.WithDocumentSyntaxRoot(documentId, cu); } @@ -497,30 +449,37 @@ private async Task RemoveInvalidAttributes(Solution solution, Document return solution; } - var attributes = cu.DescendantNodes().OfType(); - var firstAttribute = attributes.FirstOrDefault(); + var attributeLists = cu.DescendantNodes().OfType().ToArray(); + var firstAttributeList = attributeLists.FirstOrDefault(); - var invalidAttributes = attributes - .Where(attr => attr.Attributes.Any(attribute => + var invalidAttributes = attributeLists + .SelectMany(static attr => attr.Attributes) + .Where(attribute => attribute.ArgumentList?.Arguments.Any(arg => arg.Expression is TypeOfExpressionSyntax typeOfExpr && - model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true)) + model.GetTypeInfo(typeOfExpr.Type).Type?.TypeKind == TypeKind.Error) == true) .ToHashSet(); if (invalidAttributes.Count > 0) { + var firstAttributeListRemoved = firstAttributeList != null && + firstAttributeList.Attributes.All(invalidAttributes.Contains); + var leadingTrivia = firstAttributeList?.GetLeadingTrivia(); cu = cu.RemoveNodes(invalidAttributes, SyntaxRemoveOptions.KeepNoTrivia)!; + var emptyAttributeLists = cu.DescendantNodes().OfType() + .Where(static list => list.Attributes.Count == 0) + .ToArray(); + cu = cu.RemoveNodes(emptyAttributeLists, SyntaxRemoveOptions.KeepNoTrivia)!; - if (invalidAttributes.Contains(firstAttribute!)) + if (firstAttributeListRemoved && leadingTrivia != null) { - var leadingTrivia = firstAttribute!.GetLeadingTrivia(); // Find where XML docs end and indentation begins var xmlDocTrivia = new List(); var lastXmlIndex = -1; - for (int i = 0; i < leadingTrivia.Count; i++) + for (int i = 0; i < leadingTrivia.Value.Count; i++) { - var trivia = leadingTrivia[i]; + var trivia = leadingTrivia.Value[i]; if (trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia)) { lastXmlIndex = i; @@ -532,14 +491,14 @@ arg.Expression is TypeOfExpressionSyntax typeOfExpr && { for (int i = 0; i <= lastXmlIndex; i++) { - xmlDocTrivia.Add(leadingTrivia[i]); + xmlDocTrivia.Add(leadingTrivia.Value[i]); } // Include the newline after the last XML doc if present - if (lastXmlIndex + 1 < leadingTrivia.Count && - leadingTrivia[lastXmlIndex + 1].IsKind(SyntaxKind.EndOfLineTrivia)) + if (lastXmlIndex + 1 < leadingTrivia.Value.Count && + leadingTrivia.Value[lastXmlIndex + 1].IsKind(SyntaxKind.EndOfLineTrivia)) { - xmlDocTrivia.Add(leadingTrivia[lastXmlIndex + 1]); + xmlDocTrivia.Add(leadingTrivia.Value[lastXmlIndex + 1]); } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs new file mode 100644 index 00000000000..eafe1d9d546 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/PostProcessing/ProviderReferenceMapResult.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System.Collections.Generic; + +namespace Microsoft.TypeSpec.Generator +{ + internal sealed record ProviderReferenceMapResult( + HashSet InternalizeCandidates, + HashSet PublicizeCandidates, + HashSet RemoveCandidates) + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs index 37bac4c1202..33d0fa5eb32 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/PropertyDescriptionBuilder.cs @@ -65,7 +65,7 @@ internal static IReadOnlyList GetUnionTypesDescriptions(IReadOn } else { - description = new XmlDocStatement("description", [$"{item:C}"]); + description = new XmlDocStatement("description", [$"{item}"]); } values.Add(description); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs index 49fe9723973..eb07aa4519f 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Primitives/TypeProviderWriter.cs @@ -45,7 +45,7 @@ private bool IsPublicContext(TypeProvider provider) private void WriteType(CodeWriter writer) { - if (IsPublicContext(_provider)) + if (_provider.PreserveTypeXmlDocs || IsPublicContext(_provider)) { writer.WriteXmlDocsNoScope(_provider.XmlDocs); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs index ed3a45d54e8..26cc96af377 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/NamedTypeSymbolProvider.cs @@ -22,6 +22,7 @@ internal sealed class NamedTypeSymbolProvider : TypeProvider { private INamedTypeSymbol _namedTypeSymbol; private readonly Compilation _compilation; + private string? _metadataName; private TypeProvider? _baseTypeProvider; public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation compilation) @@ -30,6 +31,23 @@ public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation com _compilation = compilation; } + internal string MetadataName + { + get + { + if (_metadataName != null) + { + return _metadataName; + } + + var ns = _namedTypeSymbol.ContainingNamespace.GetFullyQualifiedNameFromDisplayString(); + _metadataName = string.IsNullOrEmpty(ns) ? _namedTypeSymbol.Name : $"{ns}.{_namedTypeSymbol.Name}"; + return _metadataName; + } + } + + internal string MetadataSimpleName => _namedTypeSymbol.Name; + private protected sealed override NamedTypeSymbolProvider? BuildCustomCodeView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => null; private protected sealed override TypeProvider? BuildLastContractView(string? generatedTypeName = default, string? generatedTypeNamespace = default) => null; @@ -321,6 +339,165 @@ [.. methodSymbol.Parameters.Select(p => ConvertToParameterProvider(methodSymbol, return [.. methods]; } + protected internal override IReadOnlyList BuildBodyDependencyTypes() + { + var dependencies = new HashSet(); + foreach (var syntaxReference in _namedTypeSymbol.DeclaringSyntaxReferences) + { + AddBodyDependencyTypes(syntaxReference.GetSyntax(), dependencies); + } + + return [.. dependencies]; + } + + protected internal override IReadOnlyList BuildSignatureDependencyTypes() + { + var dependencies = new HashSet(); + foreach (var syntaxReference in _namedTypeSymbol.DeclaringSyntaxReferences) + { + if (syntaxReference.GetSyntax() is not TypeDeclarationSyntax typeDeclaration || + !IsPublic(typeDeclaration.Modifiers)) + { + continue; + } + + AddSyntaxTypeReferences(typeDeclaration.BaseList, dependencies); + foreach (var member in typeDeclaration.Members) + { + if (IsPublicApiMember(member)) + { + AddPublicSignatureDependencyTypes(member, dependencies); + } + } + } + + return [.. dependencies]; + } + + private void AddBodyDependencyTypes(SyntaxNode syntax, HashSet dependencies) + { + AddSyntaxTypeReferences(syntax, dependencies); + + foreach (var invocation in syntax.DescendantNodes().OfType()) + { + if (GetInvocationName(invocation) == "SetDelimited") + { + dependencies.Add(CreateUnresolvedDependencyType("SetDelimited")); + } + } + } + + private static void AddPublicSignatureDependencyTypes(MemberDeclarationSyntax member, HashSet dependencies) + { + switch (member) + { + case MethodDeclarationSyntax method: + AddSyntaxTypeReferences(method.ReturnType, dependencies); + AddSyntaxTypeReferences(method.ParameterList, dependencies); + AddSyntaxTypeReferences(method.ConstraintClauses, dependencies); + break; + case ConstructorDeclarationSyntax constructor: + AddSyntaxTypeReferences(constructor.ParameterList, dependencies); + break; + case ConversionOperatorDeclarationSyntax conversion: + AddSyntaxTypeReferences(conversion.Type, dependencies); + AddSyntaxTypeReferences(conversion.ParameterList, dependencies); + break; + case OperatorDeclarationSyntax @operator: + AddSyntaxTypeReferences(@operator.ReturnType, dependencies); + AddSyntaxTypeReferences(@operator.ParameterList, dependencies); + break; + case PropertyDeclarationSyntax property: + AddSyntaxTypeReferences(property.Type, dependencies); + break; + case IndexerDeclarationSyntax indexer: + AddSyntaxTypeReferences(indexer.Type, dependencies); + AddSyntaxTypeReferences(indexer.ParameterList, dependencies); + break; + case FieldDeclarationSyntax field: + AddSyntaxTypeReferences(field.Declaration.Type, dependencies); + break; + case EventFieldDeclarationSyntax eventField: + AddSyntaxTypeReferences(eventField.Declaration.Type, dependencies); + break; + case EventDeclarationSyntax @event: + AddSyntaxTypeReferences(@event.Type, dependencies); + break; + case DelegateDeclarationSyntax @delegate: + AddSyntaxTypeReferences(@delegate.ReturnType, dependencies); + AddSyntaxTypeReferences(@delegate.ParameterList, dependencies); + AddSyntaxTypeReferences(@delegate.ConstraintClauses, dependencies); + break; + case BaseTypeDeclarationSyntax type: + AddSyntaxTypeReferences(type.BaseList, dependencies); + break; + } + } + + private static void AddSyntaxTypeReferences(SyntaxNode? node, HashSet dependencies) + { + if (node == null) + { + return; + } + + foreach (var name in node.DescendantNodesAndSelf().OfType()) + { + dependencies.Add(CreateUnresolvedDependencyType(name.Identifier.ValueText)); + } + + foreach (var name in node.DescendantNodesAndSelf().OfType()) + { + dependencies.Add(CreateUnresolvedDependencyType(name.Identifier.ValueText)); + } + } + + private static void AddSyntaxTypeReferences(IEnumerable nodes, HashSet dependencies) + { + foreach (var node in nodes) + { + AddSyntaxTypeReferences(node, dependencies); + } + } + + private static bool IsPublicApiMember(MemberDeclarationSyntax member) + => member switch + { + EventDeclarationSyntax @event => IsPublic(@event.Modifiers), + EventFieldDeclarationSyntax @event => IsPublic(@event.Modifiers), + BaseFieldDeclarationSyntax field => IsPublic(field.Modifiers), + BaseMethodDeclarationSyntax method => IsPublic(method.Modifiers), + BasePropertyDeclarationSyntax property => IsPublic(property.Modifiers), + DelegateDeclarationSyntax @delegate => IsPublic(@delegate.Modifiers), + BaseTypeDeclarationSyntax type => IsPublic(type.Modifiers), + _ => false + }; + + private static bool IsPublic(SyntaxTokenList modifiers) + => modifiers.Any(static modifier => + modifier.IsKind(SyntaxKind.PublicKeyword) || + modifier.IsKind(SyntaxKind.ProtectedKeyword)); + + private static CSharpType CreateUnresolvedDependencyType(string name) + => new( + name, + string.Empty, + isValueType: false, + isNullable: false, + declaringType: null, + args: [], + isPublic: false, + isStruct: false); + + private static string? GetInvocationName(InvocationExpressionSyntax invocation) + => invocation.Expression switch + { + IdentifierNameSyntax identifier => identifier.Identifier.ValueText, + MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.ValueText, + GenericNameSyntax genericName => genericName.Identifier.ValueText, + _ => null + }; + private static bool IsPartialMethodDeclaration(IMethodSymbol methodSymbol) { foreach (var syntaxReference in methodSymbol.DeclaringSyntaxReferences) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs index 3d71670d5a9..bbaf3e54d24 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/TypeProvider.cs @@ -143,6 +143,13 @@ public XmlDocProvider XmlDocs private set => _xmlDocs = value; } + internal bool PreserveTypeXmlDocs { get; private set; } + + internal void PreserveXmlDocs() + { + PreserveTypeXmlDocs = true; + } + public string? Deprecated { get => _deprecated; @@ -292,6 +299,22 @@ private IReadOnlyList ApplyCustomizationFilter(IEnumerable SerializationProviders => _serializationProviders ??= BuildSerializationProviders(); + private IReadOnlyList? _helperDependencyTypes; + internal IReadOnlyList HelperDependencyTypes => _helperDependencyTypes ??= BuildHelperDependencyTypes(); + protected internal virtual IReadOnlyList BuildHelperDependencyTypes() => []; + + private IReadOnlyList? _bodyDependencyTypes; + internal IReadOnlyList BodyDependencyTypes => _bodyDependencyTypes ??= BuildBodyDependencyTypes(); + protected internal virtual IReadOnlyList BuildBodyDependencyTypes() => []; + + private IReadOnlyList? _signatureDependencyTypes; + internal IReadOnlyList SignatureDependencyTypes => _signatureDependencyTypes ??= BuildSignatureDependencyTypes(); + protected internal virtual IReadOnlyList BuildSignatureDependencyTypes() => []; + + protected internal virtual bool IsClientProvider => false; + + protected internal virtual bool IncludeGeneratedBodyReferences => false; + private IReadOnlyList? _attributes; public IReadOnlyList Attributes @@ -538,6 +561,7 @@ public virtual void Reset() _serializationProviders = null; _nestedTypes = null; _xmlDocs = null; + PreserveTypeXmlDocs = false; _declarationModifiers = null; _relativeFilePath = null; _customCodeView = new(() => BuildCustomCodeView()); @@ -741,75 +765,10 @@ internal void ProcessTypeForBackCompatibility() { _enumValues = updatedEnumValues; } - - // Back-compatibility processing intentionally runs after the library visitor pass so - // that the contract comparison uses the final, post-visitor member signatures (otherwise - // we could incorrectly decide whether a back-compat member is needed). As a result, any - // members synthesized above (e.g. back-compat overloads) have not been visited yet. Run - // only those newly-added members through the visitors now so visitor transforms apply to - // them as well, without re-visiting members that were already visited during the main pass. - if (newMethods != null) - { - newMethods = VisitNewMembers(newMethods, Methods, static (member, visitor) => member.Accept(visitor)); - } - if (newConstructors != null) - { - newConstructors = VisitNewMembers(newConstructors, Constructors, static (member, visitor) => visitor.VisitConstructor(member)); - } - if (newFields != null) - { - newFields = VisitNewMembers(newFields, Fields, static (member, visitor) => visitor.VisitField(member)); - } - Update(fields: newFields, methods: newMethods, constructors: newConstructors); } } - // Runs newly-added back-compatibility members through every registered visitor while leaving - // members that were already visited during the main visitor pass untouched. Membership in the - // already-visited set is determined by reference identity against the pre-Update collection. - private static IReadOnlyList VisitNewMembers( - IEnumerable allMembers, - IReadOnlyList alreadyVisited, - Func visit) - where T : class - { - var visitors = CodeModelGenerator.Instance.Visitors; - var materialized = allMembers as IReadOnlyList ?? [.. allMembers]; - if (visitors.Count == 0) - { - return materialized; - } - - var alreadyVisitedSet = new HashSet(alreadyVisited, ReferenceEqualityComparer.Instance); - var result = new List(materialized.Count); - foreach (var member in materialized) - { - if (alreadyVisitedSet.Contains(member)) - { - result.Add(member); - continue; - } - - T? visited = member; - foreach (var visitor in visitors) - { - visited = visit(visited, visitor); - if (visited == null) - { - break; - } - } - - if (visited != null) - { - result.Add(visited); - } - } - - return result; - } - protected internal virtual IReadOnlyList? BuildEnumValuesForBackCompatibility(IReadOnlyList originalEnumValues) => null; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs new file mode 100644 index 00000000000..9fb12dd028d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.BodyReferences.cs @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static void AddGeneratedBodyReferences(IReadOnlyList providers, ProviderReferenceGraph graph) + { + foreach (var (provider, isSerializationProvider) in GetBodyReferenceProviders(providers)) + { + if (IsModelFactoryProvider(provider) || + !IsGeneratedBodyReferenceCandidate(provider, isSerializationProvider)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!graph.Nodes.Contains(providerName)) + { + continue; + } + + AddHelperDependencies(graph.References[providerName], provider.HelperDependencyTypes, graph.Nodes, referencedNames: null); + AddProviderBodyDependencyTypes( + graph.References[providerName], + GetNonEnumStructuredBodyReferenceTypes(provider, graph.Nodes), + graph.Nodes); + AddProviderBodyDependencyTypes(graph.References[providerName], provider.BodyDependencyTypes, graph.Nodes); + AddProviderInfrastructureReferences(graph.References[providerName], provider, graph.Nodes); + } + } + + private static void AddProviderInfrastructureReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddMatchingName(references, "ProviderConstants", nodes); + AddMatchingName(references, "TypeFormatters", nodes); + + if (provider.SerializationProviders.Count > 0) + { + AddSerializationExtensionReferences(references, provider, nodes); + } + + if (IsSerializationProvider(provider)) + { + AddMatchingName(references, "Optional", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + AddMatchingName(references, "ModelSerializationExtensions", nodes); + AddSerializationExtensionReferences(references, provider, nodes); + } + + foreach (var method in provider.Methods) + { + AddMethodInfrastructureReferences(references, method, nodes); + } + } + + private static void AddSerializationExtensionReferences(HashSet references, TypeProvider provider, HashSet nodes) + { + AddSerializationExtensionReferences(references, provider.Type, nodes); + AddSerializationExtensionReferences(references, provider.BaseType, nodes); + foreach (var implementedType in provider.Implements) + { + AddSerializationExtensionReferences(references, implementedType, nodes); + } + + foreach (var property in provider.Properties) + { + AddSerializationExtensionReferences(references, property.Type, nodes); + } + + foreach (var field in provider.Fields) + { + AddSerializationExtensionReferences(references, field.Type, nodes); + } + + foreach (var constructor in provider.Constructors) + { + AddSerializationExtensionReferences(references, constructor.Signature.ReturnType, nodes); + foreach (var parameter in constructor.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + + foreach (var method in provider.Methods) + { + AddSerializationExtensionReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddSerializationExtensionReferences(references, parameter.Type, nodes); + } + } + } + + private static void AddSerializationExtensionReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + AddMatchingName(references, $"{type.Name}Extensions", nodes); + foreach (var argument in type.Arguments) + { + AddSerializationExtensionReferences(references, argument, nodes); + } + } + + private static void AddMethodInfrastructureReferences(HashSet references, MethodProvider method, HashSet nodes) + { + AddReturnTypeInfrastructureReferences(references, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddRequestContentInfrastructureReferences(references, parameter.Type, nodes); + } + } + + private static void AddReturnTypeInfrastructureReferences(HashSet references, CSharpType? returnType, HashSet nodes) + { + var type = UnwrapTask(returnType); + if (type == null) + { + return; + } + + var typeName = StripGenericArity(type.Name); + if (string.Equals(typeName, "Pageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "PageableWrapper", nodes); + } + else if (string.Equals(typeName, "AsyncPageable", StringComparison.Ordinal)) + { + AddMatchingName(references, "AsyncPageableWrapper", nodes); + } + else if (string.Equals(typeName, "ArmOperation", StringComparison.Ordinal)) + { + AddMatchingNamesWithSimpleNameSuffix(references, "ArmOperation", nodes); + AddMatchingNamesWithSimpleNameSuffix(references, "OperationSource", nodes); + if (type.Arguments.Count > 0) + { + AddMatchingName(references, $"{BuildOperationSourceTypeName(type.Arguments[0])}OperationSource", nodes); + } + } + } + + private static void AddRequestContentInfrastructureReferences(HashSet references, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (string.Equals(type.Name, "RequestContent", StringComparison.Ordinal)) + { + AddMatchingName(references, "BinaryContentHelper", nodes); + AddMatchingName(references, "Utf8JsonRequestContent", nodes); + } + + foreach (var argument in type.Arguments) + { + AddRequestContentInfrastructureReferences(references, argument, nodes); + } + } + + private static CSharpType? UnwrapTask(CSharpType? type) + { + var typeName = type == null ? null : StripGenericArity(type.Name); + if ((string.Equals(typeName, "Task", StringComparison.Ordinal) || + string.Equals(typeName, "ValueTask", StringComparison.Ordinal)) && + type?.Arguments.Count > 0) + { + return type.Arguments[0]; + } + + return type; + } + + private static string BuildOperationSourceTypeName(CSharpType type) + { + var argumentNames = string.Join("", type.Arguments.Select(BuildOperationSourceTypeName)); + return $"{type.Name}{(argumentNames.Length > 0 ? "Of" : string.Empty)}{argumentNames}"; + } + + private static IReadOnlyList GetNonEnumStructuredBodyReferenceTypes(TypeProvider provider, HashSet nodes) + { + var references = new List(); + foreach (var dependency in CollectStructuredBodyReferenceTypes(provider)) + { + if (!IsEnumProviderDependency(dependency, nodes)) + { + references.Add(dependency); + } + } + + return references; + } + + private static IReadOnlyList CollectStructuredBodyReferenceTypes(TypeProvider provider) + { + var references = new HashSet(); + var visited = new HashSet(ReferenceEqualityComparer.Instance); + + foreach (var field in provider.Fields) + { + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + } + + foreach (var property in provider.Properties) + { + CollectStructuredBodyReferenceTypes(property.Body, references, visited); + } + + foreach (var constructor in provider.Constructors) + { + CollectStructuredBodyReferenceTypes(constructor.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(constructor.BodyStatements, references, visited); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + CollectStructuredBodyReferenceTypes(method.BodyExpression, references, visited); + CollectStructuredBodyReferenceTypes(method.BodyStatements, references, visited); + } + + return [.. references]; + } + + private static void CollectStructuredBodyReferenceTypes(object? value, HashSet references, HashSet visited) + { + switch (value) + { + case null: + case string: + case FormattableString: + return; + } + + if (!value.GetType().IsValueType && !visited.Add(value)) + { + return; + } + + switch (value) + { + case CSharpType type: + references.Add(type); + return; + case Type type: + references.Add(type); + return; + case ParameterProvider parameter: + references.Add(parameter.Type); + CollectStructuredBodyReferenceTypes(parameter.DefaultValue, references, visited); + CollectStructuredBodyReferenceTypes(parameter.InitializationValue, references, visited); + return; + case MethodSignatureBase signature: + CollectStructuredBodyReferenceTypes(signature.ReturnType, references, visited); + CollectStructuredBodyReferenceTypes(signature.Parameters, references, visited); + return; + case KeyValuePair positionalArgument: + CollectStructuredBodyReferenceTypes(positionalArgument.Value, references, visited); + return; + case FieldProvider field: + references.Add(field.Type); + CollectStructuredBodyReferenceTypes(field.InitializationValue, references, visited); + return; + } + + if (IsStructuredBodyReferenceObject(value)) + { + foreach (var property in value.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance)) + { + if (property.GetIndexParameters().Length > 0) + { + continue; + } + + CollectStructuredBodyReferenceTypes(property.GetValue(value), references, visited); + } + + return; + } + + if (value is not IEnumerable values) + { + return; + } + + foreach (var item in values) + { + CollectStructuredBodyReferenceTypes(item, references, visited); + } + } + + private static bool IsEnumProviderDependency(CSharpType dependency, HashSet nodes) + { + var providerName = GetProviderTypeName(dependency); + if (!nodes.Contains(providerName)) + { + return false; + } + + foreach (var provider in CodeModelGenerator.Instance.OutputLibrary.TypeProviders) + { + if (provider is EnumProvider && + string.Equals(GetProviderTypeName(provider.Type), providerName, StringComparison.Ordinal)) + { + return true; + } + } + + return false; + } + + private static bool IsStructuredBodyReferenceObject(object value) => + value is ValueExpression || + value is MethodBodyStatement || + value is PropertyBody; + + private static void AddProviderBodyDependencyTypes( + HashSet references, + IReadOnlyList dependencies, + HashSet nodes, + bool includeSimpleNameReferences = false) + { + foreach (var dependency in dependencies) + { + AddProviderBodyDependencyType(references, dependency, nodes, includeSimpleNameReferences); + } + } + + private static void AddProviderBodyDependencyType( + HashSet references, + CSharpType? dependency, + HashSet nodes, + bool includeSimpleNameReferences) + { + if (dependency == null) + { + return; + } + + AddTypeReference(references, dependency, nodes); + if (includeSimpleNameReferences) + { + AddMatchingName(references, dependency.Name, nodes); + } + AddMatchingName(references, $"{dependency.Name}Extensions", nodes); + + foreach (var argument in dependency.Arguments) + { + AddProviderBodyDependencyType(references, argument, nodes, includeSimpleNameReferences); + } + } + + private static IReadOnlyList<(TypeProvider Provider, bool IsSerializationProvider)> GetBodyReferenceProviders(IReadOnlyList providers) + { + var bodyReferenceProviders = new List<(TypeProvider Provider, bool IsSerializationProvider)>(); + foreach (var provider in providers) + { + bodyReferenceProviders.Add((provider, false)); + foreach (var serializationProvider in provider.SerializationProviders) + { + bodyReferenceProviders.Add((serializationProvider, true)); + } + } + + return bodyReferenceProviders; + } + + private static bool IsGeneratedBodyReferenceCandidate(TypeProvider provider, bool isSerializationProvider) + { + if (provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + return true; + } + + return provider.IsClientProvider || + isSerializationProvider || + provider.IncludeGeneratedBodyReferences || + provider.HelperDependencyTypes.Count > 0 || + provider.BodyDependencyTypes.Count > 0; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs new file mode 100644 index 00000000000..da3b9d05cac --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Candidates.cs @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static HashSet GetCustomInternalBoundaryNodes( + ProviderReferenceGraph publicGraph, + HashSet customInternalDeclarations) + { + var boundaryNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in publicGraph.Nodes) + { + if (!publicGraph.References.TryGetValue(node, out var references)) + { + continue; + } + + if (references.Overlaps(customInternalDeclarations)) + { + boundaryNodes.Add(node); + } + } + + return boundaryNodes; + } + + private static HashSet GetPublicizeDeclaredNodes( + IReadOnlyList generatedProviders, + HashSet nodes, + HashSet internalizeDeclaredNodes) + { + var publicizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, nodes, publicOnly: false); + return publicizeDeclaredNodes; + } + + private static HashSet GetPublicApiTraversalNodes( + HashSet internalizeDeclaredNodes, + HashSet publicizeDeclaredNodes, + HashSet generatedInternalDeclarations, + HashSet generatedImplementationInternalDeclarations) + { + var traversalNodes = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + if (generatedInternalDeclarations.Contains(node) || + generatedImplementationInternalDeclarations.Contains(node)) + { + continue; + } + + traversalNodes.Add(node); + } + + foreach (var node in publicizeDeclaredNodes) + { + if (!generatedImplementationInternalDeclarations.Contains(node)) + { + traversalNodes.Add(node); + } + } + + return traversalNodes; + } + + private static HashSet GetInternalizeCandidates( + HashSet internalizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet publicizeRoots, + HashSet nodes, + IReadOnlyDictionary> references) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in internalizeDeclaredNodes) + { + var isNonRootKept = IsKeptName(node, CodeModelGenerator.Instance.NonRootTypes, nodes); + if (!publicizeReachable.Contains(node) || + customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) && (!publicizeRoots.Contains(node) || isNonRootKept) || + isNonRootKept && + references.TryGetValue(node, out var nodeReferences) && + nodeReferences.Overlaps(customInternalDeclarations)) + { + candidates.Add(node); + } + } + + // If a public non-root type exposes something that must remain internal, make the + // exposing type internal too. That avoids generating public APIs with internal types. + var addedCandidate = true; + while (addedCandidate) + { + addedCandidate = false; + foreach (var node in internalizeDeclaredNodes) + { + var isNonRootKept = IsKeptName(node, CodeModelGenerator.Instance.NonRootTypes, nodes); + if (candidates.Contains(node) || + publicizeRoots.Contains(node) && !isNonRootKept || + !references.TryGetValue(node, out var nodeReferences) || + !nodeReferences.Overlaps(candidates)) + { + continue; + } + + candidates.Add(node); + addedCandidate = true; + } + } + + // Non-root keep entries preserve the declared type/file, but deliberately do not + // root their dependencies. They can still be internalized when needed to avoid + // exposing custom/internal types through a public surface. + RemoveKeptNonRootNames(candidates, nodes, references, customInternalDeclarations, customInternalBoundaryNodes); + return candidates; + } + + private static HashSet GetPublicizeCandidates( + HashSet publicizeDeclaredNodes, + HashSet publicizeReachable, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes, + HashSet internalizeHelperRoots, + HashSet publicizeRootExclusions, + HashSet generatedInternalDeclarations, + HashSet publicizeRoots, + Dictionary> publicApiReferences, + Dictionary> internalizeReferences, + HashSet generatedImplementationInternalDeclarations) + { + var candidates = new HashSet(StringComparer.Ordinal); + foreach (var node in publicizeDeclaredNodes) + { + if (customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node) || + internalizeHelperRoots.Contains(node) || + publicizeRootExclusions.Contains(node) || + !publicizeReachable.Contains(node)) + { + continue; + } + + if (generatedInternalDeclarations.Contains(node) && + !publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, publicApiReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + { + continue; + } + + if (!publicizeRoots.Contains(node) && + !HasPublicApiPredecessor(node, internalizeReferences, publicizeReachable, generatedImplementationInternalDeclarations)) + { + continue; + } + + candidates.Add(node); + } + + return candidates; + } + + private static HashSet GetRemovalCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + HashSet customRemovalRoots, + HashSet generatedDiscriminatorBaseNames) + { + var removeRoots = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: true, + includeModelFactorySignatureRoots: true, + publicClientRootsOnly: false); + + removeRoots.UnionWith(customRemovalRoots); + AddMatchingNamesWithSimpleNameSuffix(removeRoots, "ReferenceType", graph.Nodes); + AddKeptNonRootNames(removeRoots, graph.Nodes); + AddCustomCodeExtensionRoots(removeRoots, generatedProviders, graph.Nodes); + AddCustomizationBackedExtensionRoots(removeRoots, graph.Nodes); + AddCustomRequestHeaderExtensionsRoot(removeRoots, generatedProviders, graph.Nodes); + RemoveUnusedRequestHeaderExtensionsRoot(removeRoots, graph.References, providers); + + var removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddDerivedModelReferences(providers, graph.Nodes, graph.References, removeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + removeReachableWithoutHelpers = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachableWithoutHelpers); + + var removeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, removeReachableWithoutHelpers, graph.References); + removeRoots.UnionWith(removeHelperRoots); + + var removeReachable = GetReachableTypes(removeRoots, graph.References); + AddBasePreservedReferences(generatedProviders, graph.Nodes, graph.References, removeReachable); + + var removeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: false); + removeDeclaredNodes.ExceptWith(removeReachable); + return removeDeclaredNodes; + } + + private static void AddKeptNonRootNames(HashSet roots, HashSet nodes) + { + var nonRootTypes = CodeModelGenerator.Instance.NonRootTypes; + foreach (var node in nodes) + { + if (IsKeptName(node, nonRootTypes, nodes)) + { + roots.Add(node); + } + } + } + + private static void RemoveKeptNonRootNames( + HashSet candidates, + HashSet nodes, + IReadOnlyDictionary> references, + HashSet customInternalDeclarations, + HashSet customInternalBoundaryNodes) + { + var nonRootTypes = CodeModelGenerator.Instance.NonRootTypes; + foreach (var node in nodes) + { + if (customInternalDeclarations.Contains(node) || + customInternalBoundaryNodes.Contains(node)) + { + continue; + } + + if (IsKeptName(node, nonRootTypes, nodes) && + !HasCandidateReference(node, candidates, references)) + { + candidates.Remove(node); + } + } + } + + private static bool HasCandidateReference( + string node, + HashSet candidates, + IReadOnlyDictionary> references) + { + if (references.TryGetValue(node, out var nodeReferences)) + { + foreach (var reference in nodeReferences) + { + if (!string.Equals(reference, node, StringComparison.Ordinal) && + candidates.Contains(reference)) + { + return true; + } + } + } + + foreach (var (source, sourceReferences) in references) + { + if (!string.Equals(source, node, StringComparison.Ordinal) && + candidates.Contains(source) && + sourceReferences.Contains(node)) + { + return true; + } + } + + return false; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs new file mode 100644 index 00000000000..ca204b2b5b3 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.CustomCodeRoots.cs @@ -0,0 +1,379 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static HashSet GetCustomCodeGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: false); + } + + return roots; + } + + private static HashSet GetCustomCodePublicGeneratedTypeRoots(IReadOnlyList providers, HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + AddCustomCodeViewGeneratedTypeRoot(roots, customCodeView, generatedTypeNames); + AddCustomCodeViewRoots(roots, customCodeView, generatedTypeNames, publicOnly: true); + } + + return roots; + } + + private static IEnumerable GetCustomCodeViews(IReadOnlyList providers) + { + var visited = new HashSet(StringComparer.Ordinal); + var modelFactoryCustomCodeView = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value.CustomCodeView; + if (modelFactoryCustomCodeView != null && visited.Add(GetCustomCodeViewIdentity(modelFactoryCustomCodeView))) + { + yield return modelFactoryCustomCodeView; + } + + foreach (var provider in providers) + { + var customCodeView = provider.CustomCodeView; + if (customCodeView == null || !visited.Add(GetCustomCodeViewIdentity(customCodeView))) + { + continue; + } + + yield return customCodeView; + } + + foreach (var customTypeProvider in CodeModelGenerator.Instance.SourceInputModel.GetCustomizationTypeProviders()) + { + if (visited.Add(GetCustomCodeViewIdentity(customTypeProvider))) + { + yield return customTypeProvider; + } + } + } + + private static string GetCustomCodeViewIdentity(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataName + : GetProviderTypeName(customCodeView.Type); + + private static void AddCustomRequestHeaderExtensionsRoot(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + if (!HasCustomRequestHeaderExtensionsReference(providers)) + { + return; + } + + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeaderExtensions", nodes); + AddMatchingNamesWithSimpleNameSuffix(roots, "RequestHeadersExtensions", nodes); + } + + private static void AddCustomCodeExtensionRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", nodes); + } + } + + private static string GetCustomCodeViewSimpleName(TypeProvider customCodeView) => + customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider + ? namedTypeSymbolProvider.MetadataSimpleName + : customCodeView.Type.Name; + + private static void AddCustomCodeViewGeneratedTypeRoot(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames) + { + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(roots, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + return; + } + + AddTypeReference(roots, customCodeView.Type, generatedTypeNames); + } + + private static void AddCustomizationBackedExtensionRoots(HashSet roots, HashSet nodes) + { + foreach (var node in nodes) + { + var simpleName = GetSimpleName(node); + if (!simpleName.EndsWith("Extensions", StringComparison.Ordinal)) + { + continue; + } + + var namespaceName = GetNamespaceName(node); + if (namespaceName == null) + { + continue; + } + + var customTypeName = simpleName.Substring(0, simpleName.Length - "Extensions".Length); + if (CodeModelGenerator.Instance.SourceInputModel.FindForTypeInCustomization(namespaceName, customTypeName) != null) + { + roots.Add(node); + } + } + } + + private static void AddCustomCodeViewRoots(HashSet roots, TypeProvider customCodeView, HashSet generatedTypeNames, bool publicOnly) + { + AddTypeReference(roots, customCodeView.BaseType, generatedTypeNames); + AddProviderBodyDependencyTypes(roots, customCodeView.SignatureDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + if (!publicOnly) + { + AddProviderBodyDependencyTypes(roots, customCodeView.BodyDependencyTypes, generatedTypeNames, includeSimpleNameReferences: true); + AddAttributes(roots, customCodeView.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + AddMatchingName(roots, $"{GetCustomCodeViewSimpleName(customCodeView)}Extensions", generatedTypeNames); + } + + foreach (var implementedType in customCodeView.Implements) + { + AddTypeReference(roots, implementedType, generatedTypeNames); + } + + foreach (var constructor in customCodeView.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, constructor.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var method in customCodeView.Methods) + { + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(roots, method.Signature, generatedTypeNames, serializationProviderNamesByType: null, includeAttributes: !publicOnly); + } + + foreach (var property in customCodeView.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(roots, property.Type, generatedTypeNames); + AddTypeReference(roots, property.ExplicitInterface, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, property.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + + foreach (var field in customCodeView.Fields) + { + if (publicOnly && !IsPublic(field.Modifiers)) + { + continue; + } + + AddTypeReference(roots, field.Type, generatedTypeNames); + if (!publicOnly) + { + AddAttributes(roots, field.Attributes, generatedTypeNames, serializationProviderNamesByType: null, includeArguments: true); + } + } + } + + private static HashSet GetApiBaselineGeneratedTypeRoots(HashSet generatedTypeNames) + { + var roots = new HashSet(StringComparer.Ordinal); + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + if (string.IsNullOrEmpty(projectDirectory)) + { + return roots; + } + + var apiDirectory = Path.GetFullPath(Path.Combine(projectDirectory, "..", "api")); + if (!Directory.Exists(apiDirectory)) + { + return roots; + } + + var apiText = string.Join("\n", Directory.GetFiles(apiDirectory, "*.cs", SearchOption.AllDirectories).Select(File.ReadAllText)); + var apiDeclaredTypeNames = GetApiDeclaredTypeNames(apiText); + foreach (var fullName in generatedTypeNames) + { + var simpleName = StripGenericArity(GetSimpleName(fullName)); + var normalizedFullName = StripGenericArity(fullName); + if (!ContainsApiTypeReference(apiText, apiDeclaredTypeNames, normalizedFullName, simpleName)) + { + continue; + } + + roots.Add(fullName); + } + + return roots; + } + + private static HashSet GetApiDeclaredTypeNames(string apiText) + { + var declaredTypeNames = new HashSet(StringComparer.Ordinal); + string? currentNamespace = null; + foreach (var line in apiText.Split('\n')) + { + var namespaceMatch = Regex.Match(line, @"^namespace\s+([\w.]+)\s*\{?\s*$"); + if (namespaceMatch.Success) + { + currentNamespace = namespaceMatch.Groups[1].Value; + continue; + } + + if (currentNamespace == null) + { + continue; + } + + var declarationMatch = Regex.Match(line, @"^ \S.*?\b(class|struct|interface|enum)\s+([A-Za-z_][A-Za-z0-9_]*)(?!\s*<)(?!\w)"); + if (declarationMatch.Success) + { + declaredTypeNames.Add($"{currentNamespace}.{declarationMatch.Groups[2].Value}"); + } + } + + return declaredTypeNames; + } + + private static bool ContainsApiTypeReference(string apiText, HashSet apiDeclaredTypeNames, string fullName, string simpleName) + { + var fullNamePattern = $@"(? GetCustomCodeInternalGeneratedTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (!customCodeView.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)) + { + continue; + } + + if (customCodeView is NamedTypeSymbolProvider namedTypeSymbolProvider) + { + AddMatchingName(declarations, namedTypeSymbolProvider.MetadataSimpleName, generatedTypeNames); + } + else + { + AddTypeReference(declarations, customCodeView.Type, generatedTypeNames); + } + } + + return declarations; + } + + private static HashSet GetGeneratedPersistableModelProxyTypeNames(IReadOnlyList providers, HashSet generatedTypeNames) + { + var proxyTypes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.Attributes.Any(static attribute => IsAttributeNamed(attribute, "PersistableModelProxy"))) + { + AddTypeReference(proxyTypes, provider.Type, generatedTypeNames); + } + } + + return proxyTypes; + } + + private static HashSet GetGeneratedInternalTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Internal); + + private static HashSet GetGeneratedPublicTypeDeclarations(IReadOnlyList providers, HashSet generatedTypeNames) + => GetGeneratedTypeDeclarationsByLastContractAccessibility(providers, generatedTypeNames, TypeSignatureModifiers.Public); + + private static HashSet GetGeneratedTypeDeclarationsByLastContractAccessibility( + IReadOnlyList providers, + HashSet generatedTypeNames, + TypeSignatureModifiers accessibility) + { + var declarations = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider.LastContractView?.DeclarationModifiers.HasFlag(accessibility) != true) + { + continue; + } + + AddTypeReference(declarations, provider.Type, generatedTypeNames); + } + + return declarations; + } + + private static HashSet GetGeneratedImplementationInternalTypeDeclarations( + IReadOnlyList providers, + HashSet generatedInternalDeclarations) + { + var implementationDeclarations = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (!provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)) + { + continue; + } + + AddTypeReference(implementationDeclarations, provider.Type, generatedInternalDeclarations); + } + + foreach (var name in generatedInternalDeclarations) + { + if (GetSimpleName(name).StartsWith("Internal", StringComparison.Ordinal)) + { + implementationDeclarations.Add(name); + } + } + + return implementationDeclarations; + } + + private static bool IsGeneratedInternalImplementation(TypeProvider provider) + => provider.RelativeFilePath.Contains( + $"{Path.DirectorySeparatorChar}Generated{Path.DirectorySeparatorChar}Internal{Path.DirectorySeparatorChar}", + StringComparison.Ordinal); + + private static HashSet GetSimpleNames(HashSet names) + { + var simpleNames = new HashSet(StringComparer.Ordinal); + foreach (var name in names) + { + simpleNames.Add(GetSimpleName(name)); + } + + return simpleNames; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs new file mode 100644 index 00000000000..cf7569f9c9f --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.GraphBuilder.cs @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceGraph BuildGraph(IReadOnlyList generatedProviders, bool publicOnly = false) + { + // Each generated provider becomes a node, and provider metadata supplies the edges: + // inheritance, signatures, properties, fields, nested/serialization providers, attributes, + // and selected implementation dependencies. This avoids parsing generated C# just to + // rediscover generated-to-generated references. + var serializationProviderNamesByType = GetSerializationProviderNamesByType(generatedProviders); + IReadOnlyDictionary? serializationReferenceNamesByType = publicOnly ? null : serializationProviderNamesByType; + var nodes = new HashSet(StringComparer.Ordinal); + var references = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (nodes.Add(providerName)) + { + references.Add(providerName, new HashSet(StringComparer.Ordinal)); + } + } + + foreach (var provider in generatedProviders) + { + var current = GetProviderTypeName(provider.Type); + AddTypeReference(references[current], provider.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.BaseType, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], provider.DeclaringTypeProvider?.Type, nodes, serializationReferenceNamesByType); + + if (!publicOnly && IsKept(provider.Type, CodeModelGenerator.Instance.NonRootTypes, nodes)) + { + continue; + } + + // Model factory signatures mention many models. The existing Roslyn post-processor + // removes factory methods for unreachable models, so model factory should only + // contribute helper dependencies, not model reachability edges. + if (IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var implementedType in provider.Implements) + { + AddTypeReference(references[current], implementedType, nodes, serializationReferenceNamesByType); + } + + if (!publicOnly) + { + foreach (var nestedType in provider.NestedTypes) + { + AddTypeReference(references[current], nestedType.Type, nodes, serializationReferenceNamesByType); + } + } + + if (!publicOnly) + { + foreach (var serializationProvider in provider.SerializationProviders) + { + AddTypeReference(references[current], serializationProvider.Type, nodes, serializationReferenceNamesByType); + } + } + + foreach (var property in provider.Properties) + { + if (publicOnly && !IsPublic(property.Modifiers)) + { + continue; + } + + AddTypeReference(references[current], property.Type, nodes, serializationReferenceNamesByType); + AddTypeReference(references[current], property.ExplicitInterface, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], property.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var field in provider.Fields) + { + if (publicOnly && !field.Modifiers.HasFlag(FieldModifiers.Public)) + { + continue; + } + + AddTypeReference(references[current], field.Type, nodes, serializationReferenceNamesByType); + if (!publicOnly) + { + AddAttributes(references[current], field.Attributes, nodes, serializationReferenceNamesByType, includeArguments: false); + } + } + + foreach (var constructor in provider.Constructors) + { + if (publicOnly && !IsPublic(constructor.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], constructor.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + } + + foreach (var method in provider.Methods) + { + if (method.IsMethodSuppressed()) + { + continue; + } + + if (publicOnly && !IsPublic(method.Signature.Modifiers)) + { + continue; + } + + AddSignatureReferences(references[current], method.Signature, nodes, serializationReferenceNamesByType, includeAttributes: !publicOnly, includeAttributeArguments: false); + if (!publicOnly) + { + AddTypeReference(references[current], GetCollectionDefinitionType(method), nodes, serializationReferenceNamesByType); + } + } + } + + return new ProviderReferenceGraph(nodes, references); + } + + private static Dictionary GetSerializationProviderNamesByType(IReadOnlyList generatedProviders) + { + var namesByType = new Dictionary>(StringComparer.Ordinal); + foreach (var provider in generatedProviders) + { + if (provider.SerializationProviders.Count == 0) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!namesByType.TryGetValue(providerName, out var serializationProviderNames)) + { + serializationProviderNames = new HashSet(StringComparer.Ordinal); + namesByType.Add(providerName, serializationProviderNames); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + serializationProviderNames.Add(GetProviderTypeName(serializationProvider.Type)); + } + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (providerName, serializationProviderNames) in namesByType) + { + result.Add(providerName, [.. serializationProviderNames]); + } + + return result; + } + + private static CSharpType? GetCollectionDefinitionType(MethodProvider method) + { + var property = method.GetType().GetProperty("CollectionDefinition"); + return property?.GetValue(method) is TypeProvider collectionDefinition + ? collectionDefinition.Type + : null; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs new file mode 100644 index 00000000000..9dd16e07a52 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.Helpers.cs @@ -0,0 +1,494 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static bool IsKept(CSharpType type, HashSet roots, HashSet nodes) + { + var providerName = GetProviderTypeName(type); + return IsKeptName(providerName, type.Name, roots, nodes); + } + + private static bool IsKeptName(string providerName, HashSet roots, HashSet nodes) + { + return IsKeptName(providerName, StripGenericArity(GetSimpleName(providerName)), roots, nodes); + } + + private static bool IsKeptName(string providerName, string simpleName, HashSet roots, HashSet nodes) + { + if (roots.Contains(providerName) && nodes.Contains(providerName)) + { + return true; + } + + if (!roots.Contains(simpleName)) + { + return false; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + return simpleNameLookup.TryGetValue(simpleName, out var matches) && + matches.Length == 1 && + string.Equals(matches[0], providerName, StringComparison.Ordinal); + } + + private static bool IsClientProviderRoot(TypeProvider provider, bool publicOnly) => + provider.IsClientProvider && + (!publicOnly || !HasApiBaselineDirectory() && provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + + private static bool IsAdditionalRootProvider(TypeProvider provider, HashSet roots, HashSet nodes) + { + if (provider.DeclaringTypeProvider != null || !IsKept(provider.Type, roots, nodes)) + { + return false; + } + + return provider is not ModelProvider && provider is not EnumProvider; + } + + private static bool HasApiBaselineDirectory() + { + var projectDirectory = CodeModelGenerator.Instance.Configuration.ProjectDirectory; + return !string.IsNullOrEmpty(projectDirectory) && + Directory.Exists(Path.GetFullPath(Path.Combine(projectDirectory, "..", "api"))); + } + + private static bool IsModelFactoryProvider(TypeProvider provider) + => provider is ModelFactoryProvider; + + private static HashSet GetHelperRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet reachableTypes, + IReadOnlyDictionary>? references = null) + { + var roots = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + var isModelFactory = IsModelFactoryProvider(provider); + if (!reachableTypes.Contains(providerName) && !isModelFactory) + { + continue; + } + + AddHelperDependencies(roots, provider.HelperDependencyTypes, nodes, references == null ? null : references[providerName]); + + foreach (var property in provider.Properties) + { + AddInitializationHelperRoot(roots, property.Type, nodes); + AddParameterValidationHelperRoot(roots, property.AsParameter, nodes); + } + + foreach (var field in provider.Fields) + { + AddParameterValidationHelperRoot(roots, field.AsParameter, nodes); + } + + foreach (var constructor in provider.Constructors) + { + foreach (var parameter in constructor.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + } + } + + foreach (var method in provider.Methods) + { + // Only factory methods for reachable models can instantiate collection helpers. + if (isModelFactory && + (method.Signature.ReturnType == null || !reachableTypes.Contains(GetProviderTypeName(method.Signature.ReturnType)))) + { + continue; + } + + foreach (var parameter in method.Signature.Parameters) + { + AddParameterValidationHelperRoot(roots, parameter, nodes); + if (isModelFactory) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, parameter.Type, nodes); + } + } + } + } + + return roots; + } + + private static void AddParameterValidationHelperRoot(HashSet roots, ParameterProvider parameter, HashSet nodes) + { + if (parameter.Validation != ParameterValidationType.None) + { + AddMatchingName(roots, "Argument", nodes); + } + } + + private static void AddHelperDependencies( + HashSet roots, + IReadOnlyList dependencies, + HashSet nodes, + HashSet? referencedNames) + { + foreach (var dependency in dependencies) + { + if (referencedNames == null) + { + AddTypeReference(roots, dependency, nodes); + continue; + } + + var matches = new HashSet(StringComparer.Ordinal); + AddTypeReference(matches, dependency, nodes); + foreach (var match in matches) + { + if (referencedNames.Contains(match)) + { + roots.Add(match); + } + } + } + } + + private static void RemoveUnusedRequestHeaderExtensionsRoot( + HashSet roots, + IReadOnlyDictionary> references, + IReadOnlyList providers) + { + var hasCustomReference = HasCustomRequestHeaderExtensionsReference(providers); + if (hasCustomReference) + { + return; + } + + var unusedRequestHeaderExtensions = new List(); + foreach (var root in roots) + { + if (IsRequestHeadersExtensionsRoot(root) && + !HasExternalReference(root, references)) + { + unusedRequestHeaderExtensions.Add(root); + } + } + + roots.ExceptWith(unusedRequestHeaderExtensions); + } + + private static bool HasExternalReference(string root, IReadOnlyDictionary> references) + { + foreach (var (source, sourceReferences) in references) + { + if (!string.Equals(source, root, StringComparison.Ordinal) && + sourceReferences.Contains(root)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeadersExtensionsRoot(string root) => + root.EndsWith(".RequestHeaderExtensions", StringComparison.Ordinal) || + root.EndsWith(".RequestHeadersExtensions", StringComparison.Ordinal); + + private static bool HasCustomRequestHeaderExtensionsReference(IReadOnlyList providers) + { + foreach (var customCodeView in GetCustomCodeViews(providers)) + { + if (customCodeView is NamedTypeSymbolProvider) + { + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.SignatureDependencyTypes)) + { + return true; + } + + continue; + } + + if (HasRequestHeaderExtensionsDependency(customCodeView.HelperDependencyTypes) || + HasRequestHeaderExtensionsDependency(customCodeView.BodyDependencyTypes) || + HasRequestHeaderExtensionsMethodDependency(customCodeView.Methods) || + HasRequestHeaderExtensionsPropertyDependency(customCodeView.Properties) || + HasRequestHeaderExtensionsFieldDependency(customCodeView.Fields)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsDependency(IEnumerable dependencies) + { + foreach (var dependency in dependencies) + { + if (IsRequestHeaderExtensionsDependency(dependency)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsMethodDependency(IReadOnlyList methods) + { + foreach (var method in methods) + { + if (IsRequestHeaderExtensionsDependency(method.Signature.ReturnType)) + { + return true; + } + + foreach (var parameter in method.Signature.Parameters) + { + if (IsRequestHeaderExtensionsDependency(parameter.Type)) + { + return true; + } + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsPropertyDependency(IReadOnlyList properties) + { + foreach (var property in properties) + { + if (IsRequestHeaderExtensionsDependency(property.Type)) + { + return true; + } + } + + return false; + } + + private static bool HasRequestHeaderExtensionsFieldDependency(IReadOnlyList fields) + { + foreach (var field in fields) + { + if (IsRequestHeaderExtensionsDependency(field.Type)) + { + return true; + } + } + + return false; + } + + private static bool IsRequestHeaderExtensionsDependency(string name) + => string.Equals(name, "RequestHeaderExtensions", StringComparison.Ordinal) || + string.Equals(name, "SetDelimited", StringComparison.Ordinal); + + private static bool IsRequestHeaderExtensionsDependency(CSharpType? type) + { + if (type == null) + { + return false; + } + + if (IsRequestHeaderExtensionsDependency(type.Name)) + { + return true; + } + + foreach (var argument in type.Arguments) + { + if (IsRequestHeaderExtensionsDependency(argument)) + { + return true; + } + } + + return false; + } + + private static bool IsSerializationProvider(TypeProvider provider) + { + var relativePath = provider.RelativeFilePath.Replace('\\', '/'); + return relativePath.EndsWith(".Serialization.cs", StringComparison.Ordinal) || + relativePath.EndsWith(".Serialization.Multipart.cs", StringComparison.Ordinal); + } + + private static void AddInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + var initializationType = type.PropertyInitializationType; + if (!string.Equals(initializationType.FullyQualifiedName, type.FullyQualifiedName, StringComparison.Ordinal)) + { + AddMatchingName(roots, initializationType.Name, nodes); + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddModelFactoryCollectionInitializationHelperRoot(HashSet roots, CSharpType? type, HashSet nodes) + { + if (type == null) + { + return; + } + + if (type is { IsList: true, IsReadOnlyMemory: false }) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.ListInitializationType, nodes); + } + + if (type.IsDictionary) + { + AddTypeReference(roots, CodeModelGenerator.Instance.TypeFactory.DictionaryInitializationType, nodes); + } + + foreach (var argument in type.Arguments) + { + AddModelFactoryCollectionInitializationHelperRoot(roots, argument, nodes); + } + } + + private static void AddMatchingName(HashSet target, string name, HashSet nodes) + { + if (nodes.Contains(name)) + { + target.Add(name); + return; + } + + var simpleNameLookup = _simpleNameLookupCache.GetValue(nodes, BuildSimpleNameLookup); + if (!simpleNameLookup.TryGetValue(name, out var matches)) + { + return; + } + + foreach (var match in matches) + { + target.Add(match); + } + } + + private static void AddMatchingNamesWithSimpleNameSuffix(HashSet target, string suffix, HashSet nodes) + { + foreach (var node in nodes) + { + if (GetSimpleName(node).EndsWith(suffix, StringComparison.Ordinal)) + { + target.Add(node); + } + } + } + + private static Dictionary BuildSimpleNameLookup(HashSet nodes) + { + var lookup = new Dictionary>(StringComparer.Ordinal); + foreach (var node in nodes) + { + var simpleName = StripGenericArity(GetSimpleName(node)); + if (!lookup.TryGetValue(simpleName, out var matchingNodes)) + { + matchingNodes = []; + lookup.Add(simpleName, matchingNodes); + } + + matchingNodes.Add(node); + } + + var result = new Dictionary(StringComparer.Ordinal); + foreach (var (simpleName, matchingNodes) in lookup) + { + result.Add(simpleName, [.. matchingNodes]); + } + + return result; + } + + private static HashSet GetReachableTypes(HashSet roots, IReadOnlyDictionary> references) + { + return GetReachableTypes(roots, references, expandableNodes: null); + } + + private static HashSet GetReachableTypes( + HashSet roots, + IReadOnlyDictionary> references, + HashSet? expandableNodes) + { + var reachable = new HashSet(StringComparer.Ordinal); + var queue = new Queue(roots); + while (queue.Count > 0) + { + var current = queue.Dequeue(); + if (!reachable.Add(current)) + { + continue; + } + + if (expandableNodes != null && !expandableNodes.Contains(current)) + { + continue; + } + + if (!references.TryGetValue(current, out var children)) + { + continue; + } + + foreach (var child in children) + { + queue.Enqueue(child); + } + } + + return reachable; + } + + private static bool HasPublicApiPredecessor( + string name, + IReadOnlyDictionary> references, + HashSet publicizeReachable, + HashSet generatedImplementationInternalDeclarations) + { + foreach (var (owner, children) in references) + { + if (!publicizeReachable.Contains(owner) || + string.Equals(owner, name, StringComparison.Ordinal) || + generatedImplementationInternalDeclarations.Contains(owner) || + !children.Contains(name)) + { + continue; + } + + return true; + } + + return false; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs new file mode 100644 index 00000000000..77327749735 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.ReferenceTraversal.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static Dictionary> CloneReferences(IReadOnlyDictionary> references) + { + var clone = new Dictionary>(StringComparer.Ordinal); + foreach (var (name, referencedNames) in references) + { + clone.Add(name, new HashSet(referencedNames, StringComparer.Ordinal)); + } + + return clone; + } + + private static void AddDerivedModelReferences( + IReadOnlyList providers, + HashSet nodes, + Dictionary> references, + HashSet publicBaseModels, + HashSet generatedDiscriminatorBaseNames) + { + var modelProviders = new List(); + var discriminatorProviders = new List(); + var discriminatorBaseNames = new HashSet(StringComparer.Ordinal); + foreach (var provider in providers) + { + if (provider is not ModelProvider modelProvider || + !modelProvider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + modelProviders.Add(modelProvider); + + if (modelProvider.DiscriminatorProperty != null) + { + discriminatorBaseNames.Add(GetProviderTypeName(modelProvider.Type)); + } + + if (!modelProvider.IsUnknownDiscriminatorModel && + (modelProvider.DiscriminatorProperty != null || modelProvider.DiscriminatorValue != null)) + { + discriminatorProviders.Add(modelProvider); + } + } + + discriminatorBaseNames.UnionWith(generatedDiscriminatorBaseNames); + var addedReference = true; + while (addedReference) + { + addedReference = false; + foreach (var provider in discriminatorProviders) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + if (!publicBaseModels.Contains(providerName)) + { + continue; + } + + foreach (var derivedModel in provider.DerivedModels) + { + if (derivedModel.IsUnknownDiscriminatorModel || + !derivedModel.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var before = references[providerName].Count; + AddTypeReference(references[providerName], derivedModel.Type, nodes); + var derivedName = GetProviderTypeName(derivedModel.Type); + if (nodes.Contains(derivedName) && publicBaseModels.Add(derivedName) || references[providerName].Count != before) + { + addedReference = true; + } + } + } + + foreach (var provider in modelProviders) + { + if (provider.IsUnknownDiscriminatorModel || + !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || + !discriminatorBaseNames.Contains(baseTypeName) || + !nodes.Contains(baseTypeName) || + !publicBaseModels.Contains(baseTypeName)) + { + continue; + } + + var before = references[baseTypeName].Count; + references[baseTypeName].Add(providerName); + if (publicBaseModels.Add(providerName) || references[baseTypeName].Count != before) + { + addedReference = true; + } + } + } + } + + private static void AddBasePreservedReferences( + IReadOnlyList providers, + HashSet nodes, + IReadOnlyDictionary> references, + HashSet reachableTypes) + { + var basePreservedRoots = new HashSet(StringComparer.Ordinal); + var addedRoot = true; + while (addedRoot) + { + addedRoot = false; + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (!nodes.Contains(providerName) || reachableTypes.Contains(providerName) || basePreservedRoots.Contains(providerName)) + { + continue; + } + + var baseTypeName = provider.BaseType == null ? null : GetProviderTypeName(provider.BaseType); + if (baseTypeName == null || !reachableTypes.Contains(baseTypeName)) + { + continue; + } + + if (basePreservedRoots.Add(providerName)) + { + addedRoot = true; + } + } + + if (addedRoot) + { + reachableTypes.UnionWith(GetReachableTypes(basePreservedRoots, references)); + } + } + } + + private static IReadOnlyList GetGeneratedProviders(IReadOnlyList providers) + { + var generatedProviders = new List(); + foreach (var provider in providers) + { + AddGeneratedProvider(generatedProviders, provider); + } + + return generatedProviders; + } + + private static void AddGeneratedProvider(List generatedProviders, TypeProvider provider) + { + generatedProviders.Add(provider); + foreach (var nestedType in provider.NestedTypes) + { + AddGeneratedProvider(generatedProviders, nestedType); + } + + foreach (var serializationProvider in provider.SerializationProviders) + { + AddGeneratedProvider(generatedProviders, serializationProvider); + } + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs new file mode 100644 index 00000000000..26dea9955f8 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.RootSelection.cs @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static HashSet GetRootNames( + IReadOnlyList providers, + HashSet nodes, + HashSet helperRoots, + bool includeModelFactory, + bool includeAdditionalRoots, + bool includeUnionVariantRoots, + bool includeModelFactorySignatureRoots, + bool publicClientRootsOnly) + { + var generator = CodeModelGenerator.Instance; + var roots = new HashSet(StringComparer.Ordinal); + var modelFactoryName = GetProviderTypeName(generator.OutputLibrary.ModelFactory.Value.Type); + + foreach (var provider in providers) + { + var name = GetProviderTypeName(provider.Type); + if (IsClientProviderRoot(provider, publicClientRootsOnly) || + includeAdditionalRoots && IsAdditionalRootProvider(provider, generator.AdditionalRootTypes, nodes) || + includeModelFactory && string.Equals(name, modelFactoryName, StringComparison.Ordinal) || + includeModelFactory && helperRoots.Contains(name)) + { + roots.Add(name); + } + } + + if (includeModelFactorySignatureRoots) + { + AddLastContractModelFactorySignatureRoots(providers, roots, nodes); + } + + if (!includeUnionVariantRoots) + { + return roots; + } + + AddUnionVariantRoots(roots, providers, nodes); + + return roots; + } + + private static void AddLastContractModelFactorySignatureRoots(IReadOnlyList providers, HashSet roots, HashSet nodes) + { + foreach (var provider in providers) + { + if (!IsModelFactoryProvider(provider)) + { + continue; + } + + foreach (var method in provider.LastContractView?.Methods ?? []) + { + if (!method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public) || + IsImplementationOnlyModelFactoryMethod(method)) + { + continue; + } + + AddTypeReference(roots, method.Signature.ReturnType, nodes); + foreach (var parameter in method.Signature.Parameters) + { + AddTypeReference(roots, parameter.Type, nodes); + } + } + } + } + + private static void AddUnionVariantRoots(HashSet roots, IReadOnlyList providers, HashSet nodes) + { + var unionVariantTypesToKeep = CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep; + foreach (var provider in GetGeneratedProviders(providers)) + { + if (provider is not ModelProvider || + !unionVariantTypesToKeep.Contains(provider.Type.Name) || + string.Equals(provider.Type.Namespace, "TypeSpec.Http", StringComparison.Ordinal)) + { + continue; + } + + AddMatchingName(roots, GetProviderTypeName(provider.Type), nodes); + } + } + + private static bool ShouldUseUnionVariantFallbackRoots() => + !HasApiBaselineDirectory() && + CodeModelGenerator.Instance.SourceInputModel.LastContract == null; + + private static bool IsImplementationOnlyModelFactoryMethod(MethodProvider method) + { + var returnType = method.Signature.ReturnType; + if (returnType == null) + { + return true; + } + + var returnTypeName = GetSimpleName(GetProviderTypeName(returnType)); + return returnTypeName.StartsWith("Paged", StringComparison.Ordinal) || + returnTypeName.EndsWith("Request", StringComparison.Ordinal); + } + + private static void RemoveMethodsFromModelFactory(HashSet namesToRemove) + { + if (namesToRemove.Count == 0) + { + return; + } + + var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value; + _preWriteModelFactory = modelFactory; + _preWriteModelFactoryMethods ??= [.. modelFactory.Methods]; + var methodsToKeep = new List(); + foreach (var method in modelFactory.Methods) + { + if (!namesToRemove.Contains(method.Signature.Name)) + { + methodsToKeep.Add(method); + } + } + + modelFactory.Update(methods: methodsToKeep); + } + + private static HashSet GetPostProcessorDeclaredNodes(IReadOnlyList providers, HashSet nodes, bool publicOnly) + { + var generator = CodeModelGenerator.Instance; + var declaredNodes = new HashSet(StringComparer.Ordinal); + foreach (var provider in GetGeneratedProviders(providers)) + { + if (IsModelFactoryProvider(provider)) + { + continue; + } + + if (publicOnly && !provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)) + { + continue; + } + + var name = GetProviderTypeName(provider.Type); + if (!nodes.Contains(name)) + { + continue; + } + + declaredNodes.Add(name); + } + + return declaredNodes; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs new file mode 100644 index 00000000000..e6c4c9fd958 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.TypeReferenceCollector.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static void AddSignatureReferences( + HashSet references, + MethodSignatureBase signature, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeAttributes = true, + bool includeAttributeArguments = true) + { + AddTypeReference(references, signature.ReturnType, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, signature.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + + foreach (var parameter in signature.Parameters) + { + AddTypeReference(references, parameter.Type, nodes, serializationProviderNamesByType); + if (includeAttributes) + { + AddAttributes(references, parameter.Attributes, nodes, serializationProviderNamesByType, includeAttributeArguments); + } + } + + if (signature is MethodSignature methodSignature) + { + AddTypeReference(references, methodSignature.ExplicitInterface, nodes, serializationProviderNamesByType); + if (methodSignature.GenericArguments != null) + { + foreach (var genericArgument in methodSignature.GenericArguments) + { + AddTypeReference(references, genericArgument, nodes, serializationProviderNamesByType); + } + } + + if (methodSignature.GenericParameterConstraints != null) + { + foreach (var constraint in methodSignature.GenericParameterConstraints) + { + AddTypeReference(references, constraint.Type, nodes, serializationProviderNamesByType); + } + } + } + + if (signature is ConstructorSignature constructorSignature) + { + AddTypeReference(references, constructorSignature.Type, nodes, serializationProviderNamesByType); + } + } + + private static void AddAttributes( + HashSet references, + IReadOnlyList attributes, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType, + bool includeArguments) + { + foreach (var attribute in attributes) + { + AddTypeReference(references, attribute.Type, nodes, serializationProviderNamesByType); + if (!includeArguments) + { + continue; + } + + foreach (var argument in attribute.Arguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + + foreach (var (_, argument) in attribute.PositionalArguments) + { + AddAttributeArgumentReference(references, argument, nodes, serializationProviderNamesByType); + } + } + } + + private static bool IsAttributeNamed(AttributeStatement attribute, string name) + => string.Equals(attribute.Type.Name, name, StringComparison.Ordinal) || + string.Equals(attribute.Type.Name, $"{name}Attribute", StringComparison.Ordinal); + + private static void AddAttributeArgumentReference( + HashSet references, + ValueExpression argument, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType) + { + if (argument is TypeOfExpression typeOf) + { + AddTypeReference(references, typeOf.Type, nodes, serializationProviderNamesByType); + AddMatchingName(references, typeOf.Type.Name, nodes); + } + } + + private static void AddTypeReference( + HashSet references, + CSharpType? type, + HashSet nodes, + IReadOnlyDictionary? serializationProviderNamesByType = null) + { + if (type == null) + { + return; + } + + if (type.IsArray) + { + AddTypeReference(references, type.ElementType, nodes, serializationProviderNamesByType); + return; + } + + var providerTypeName = GetProviderTypeName(type); + if (nodes.Contains(providerTypeName)) + { + references.Add(providerTypeName); + if (serializationProviderNamesByType != null && serializationProviderNamesByType.TryGetValue(providerTypeName, out var serializationProviderNames)) + { + foreach (var serializationProviderName in serializationProviderNames) + { + references.Add(serializationProviderName); + } + } + } + + AddTypeReference(references, type.BaseType, nodes, serializationProviderNamesByType); + AddTypeReference(references, type.DeclaringType, nodes, serializationProviderNamesByType); + foreach (var argument in type.Arguments) + { + AddTypeReference(references, argument, nodes, serializationProviderNamesByType); + } + } + + private static string GetSimpleName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? fullyQualifiedName : fullyQualifiedName.Substring(lastDot + 1); + } + + private static string? GetNamespaceName(string fullyQualifiedName) + { + var lastDot = fullyQualifiedName.LastIndexOf('.'); + return lastDot < 0 ? null : fullyQualifiedName.Substring(0, lastDot); + } + + private static string GetProviderTypeName(CSharpType type) + { + var name = type.Arguments.Count > 0 && !type.Name.Contains('`', StringComparison.Ordinal) + ? $"{type.Name}`{type.Arguments.Count}" + : type.Name; + return string.IsNullOrEmpty(type.Namespace) ? name : $"{type.Namespace}.{name}"; + } + + private static string StripGenericArity(string name) + { + var tick = name.IndexOf('`'); + return tick < 0 ? name : name.Substring(0, tick); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs new file mode 100644 index 00000000000..c89186bc345 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/ReferenceMap/ProviderReferenceMapAnalyzer.cs @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text.RegularExpressions; +using Microsoft.TypeSpec.Generator.Expressions; +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Statements; + +namespace Microsoft.TypeSpec.Generator +{ + internal static partial class ProviderReferenceMapAnalyzer + { + private static ProviderReferenceMapResult? _latestResult; + private static readonly ConditionalWeakTable, Dictionary> _simpleNameLookupCache = new(); + private static TypeProvider? _preWriteModelFactory; + private static MethodProvider[]? _preWriteModelFactoryMethods; + + public static ProviderReferenceMapResult? LatestResult => _latestResult; + public static bool PreWriteAccessibilityApplied { get; private set; } + + public static bool ShouldWriteProvider(TypeProvider provider) => + _latestResult?.RemoveCandidates.Contains(GetProviderTypeName(provider.Type)) != true; + + public static void ResetPreWriteAccessibility() + { + RestorePreWriteModelFactoryMethods(); + _latestResult = null; + PreWriteAccessibilityApplied = false; + } + + public static void ApplyPreWriteAccessibility(IReadOnlyList providers) + { + PreWriteAccessibilityApplied = false; + if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.KeepAll) + { + return; + } + + // Accessibility has to be adjusted before files are written. Roslyn can remove files + // later, but it cannot safely change provider declarations or model factory signatures. + var (internalizeCandidates, publicizeCandidates) = GetPreWriteAccessibilityCandidates(providers); + foreach (var provider in GetGeneratedProviders(providers)) + { + var providerName = GetProviderTypeName(provider.Type); + if (internalizeCandidates.Contains(providerName)) + { + provider.PreserveXmlDocs(); + provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers)); + } + else if (publicizeCandidates.Contains(providerName) && !IsGeneratedInternalImplementation(provider)) + { + provider.Update(modifiers: MakePublic(provider.DeclarationModifiers)); + } + } + + RemoveMethodsFromModelFactory(GetSimpleNames(internalizeCandidates)); + PreWriteAccessibilityApplied = true; + } + + public static void RestorePreWriteModelFactoryMethods() + { + if (_preWriteModelFactory == null || _preWriteModelFactoryMethods == null) + { + return; + } + + _preWriteModelFactory.Update(methods: _preWriteModelFactoryMethods); + _preWriteModelFactory = null; + _preWriteModelFactoryMethods = null; + } + + public static void Analyze(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + + // Build two graphs from provider metadata: the full implementation graph for removal, + // and the public-surface graph for accessibility decisions. + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customCodeRemovalRoots = GetCustomCodeGeneratedTypeRoots(generatedProviders, graph.Nodes); + var customRemovalRoots = new HashSet(customCodeRemovalRoots, StringComparer.Ordinal); + customRemovalRoots.UnionWith(apiBaselineGeneratedTypeRoots); + customRemovalRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + + // Helper types are rooted after an initial reachability pass so unused infrastructure + // such as change-tracking dictionaries can still be removed when no reachable type needs them. + var generatedDiscriminatorBaseNames = GetGeneratedPersistableModelProxyTypeNames(generatedProviders, publicGraph.Nodes); + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + // Body-only generated dependencies are needed to avoid deleting helper files, but they do + // not contribute to public API reachability for internalization. + AddGeneratedBodyReferences(providers, graph); + var removeCandidates = GetRemovalCandidates( + providers, + generatedProviders, + graph, + customRemovalRoots, + generatedDiscriminatorBaseNames); + + _latestResult = new ProviderReferenceMapResult( + internalizeCandidates, + publicizeCandidates, + removeCandidates); + RemoveMethodsFromModelFactory(GetSimpleNames(removeCandidates)); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates) GetPreWriteAccessibilityCandidates(IReadOnlyList providers) + { + var generatedProviders = GetGeneratedProviders(providers); + var graph = BuildGraph(generatedProviders); + var publicGraph = BuildGraph(generatedProviders, publicOnly: true); + var customPublicRoots = GetCustomCodePublicGeneratedTypeRoots(generatedProviders, graph.Nodes); + var apiBaselineGeneratedTypeRoots = GetApiBaselineGeneratedTypeRoots(graph.Nodes); + customPublicRoots.UnionWith(apiBaselineGeneratedTypeRoots); + var generatedPublicDeclarations = GetGeneratedPublicTypeDeclarations(generatedProviders, graph.Nodes); + customPublicRoots.UnionWith(generatedPublicDeclarations); + var customInternalDeclarations = GetCustomCodeInternalGeneratedTypeDeclarations(generatedProviders, graph.Nodes); + var generatedInternalDeclarations = GetGeneratedInternalTypeDeclarations(generatedProviders, graph.Nodes); + var generatedDiscriminatorBaseNames = new HashSet(StringComparer.Ordinal); + + var (internalizeCandidates, publicizeCandidates, _) = GetAccessibilityCandidates( + providers, + generatedProviders, + graph, + publicGraph, + customPublicRoots, + customInternalDeclarations, + generatedInternalDeclarations, + generatedDiscriminatorBaseNames); + + return (internalizeCandidates, publicizeCandidates); + } + + private static (HashSet InternalizeCandidates, HashSet PublicizeCandidates, HashSet InternalizeHelperRoots) GetAccessibilityCandidates( + IReadOnlyList providers, + IReadOnlyList generatedProviders, + ProviderReferenceGraph graph, + ProviderReferenceGraph publicGraph, + HashSet customPublicRoots, + HashSet customInternalDeclarations, + HashSet generatedInternalDeclarations, + HashSet generatedDiscriminatorBaseNames) + { + var internalizeReferences = CloneReferences(publicGraph.References); + + // Start from public client and custom/public API roots. Anything public-reachable can + // stay public unless it crosses a custom/internal boundary. + var internalizeRoots = GetRootNames(providers, graph.Nodes, helperRoots: [], includeModelFactory: false, includeAdditionalRoots: true, includeUnionVariantRoots: false, includeModelFactorySignatureRoots: true, publicClientRootsOnly: true); + if (ShouldUseUnionVariantFallbackRoots()) + { + AddUnionVariantRoots(internalizeRoots, providers, graph.Nodes); + } + + var generatedPublicReachable = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, generatedPublicReachable, generatedDiscriminatorBaseNames); + internalizeRoots.UnionWith(customPublicRoots); + var internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + AddDerivedModelReferences(providers, publicGraph.Nodes, internalizeReferences, internalizeReachableWithoutHelpers, generatedDiscriminatorBaseNames); + internalizeReachableWithoutHelpers = GetReachableTypes(internalizeRoots, internalizeReferences); + var publicizeRoots = new HashSet(internalizeRoots, StringComparer.Ordinal); + var publicApiReferences = CloneReferences(publicGraph.References); + var internalizeHelperRoots = GetHelperRootNames(generatedProviders, graph.Nodes, internalizeReachableWithoutHelpers); + internalizeRoots.UnionWith(internalizeHelperRoots); + var internalizeDeclaredNodes = GetPostProcessorDeclaredNodes(generatedProviders, graph.Nodes, publicOnly: true); + var customInternalBoundaryNodes = GetCustomInternalBoundaryNodes(publicGraph, customInternalDeclarations); + var publicizeDeclaredNodes = GetPublicizeDeclaredNodes(generatedProviders, graph.Nodes, internalizeDeclaredNodes); + var generatedImplementationInternalDeclarations = GetGeneratedImplementationInternalTypeDeclarations(generatedProviders, generatedInternalDeclarations); + var publicApiTraversalNodes = GetPublicApiTraversalNodes( + internalizeDeclaredNodes, + publicizeDeclaredNodes, + generatedInternalDeclarations, + generatedImplementationInternalDeclarations); + var publicizeReachable = GetReachableTypes(publicizeRoots, internalizeReferences, publicApiTraversalNodes); + var internalizeCandidates = GetInternalizeCandidates( + internalizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + publicizeRoots, + graph.Nodes, + internalizeReferences); + var publicizeRootExclusions = GetRootNames( + providers, + graph.Nodes, + helperRoots: [], + includeModelFactory: true, + includeAdditionalRoots: true, + includeUnionVariantRoots: true, + includeModelFactorySignatureRoots: false, + publicClientRootsOnly: true); + var publicizeCandidates = GetPublicizeCandidates( + publicizeDeclaredNodes, + publicizeReachable, + customInternalDeclarations, + customInternalBoundaryNodes, + internalizeHelperRoots, + publicizeRootExclusions, + generatedInternalDeclarations, + publicizeRoots, + publicApiReferences, + internalizeReferences, + generatedImplementationInternalDeclarations); + return (internalizeCandidates, publicizeCandidates, internalizeHelperRoots); + } + + private static bool IsPublic(MethodSignatureModifiers modifiers) => modifiers.HasFlag(MethodSignatureModifiers.Public); + private static bool IsPublic(FieldModifiers modifiers) => modifiers.HasFlag(FieldModifiers.Public); + + private static TypeSignatureModifiers MakeInternal(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Public | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Internal; + + private static TypeSignatureModifiers MakePublic(TypeSignatureModifiers modifiers) + => (modifiers & ~(TypeSignatureModifiers.Internal | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Public; + + private sealed record ProviderReferenceGraph( + HashSet Nodes, + Dictionary> References); + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs index a329166ee4b..4792e781338 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/SourceInput/SourceInputModel.cs @@ -22,6 +22,7 @@ public class SourceInputModel public ApiCompatBaseline ApiCompatBaseline { get; } private readonly Lazy> _nameMap; + private readonly Lazy> _customizationTypeProviders; public SourceInputModel(Compilation? customization, Compilation? lastContract) : this(customization, lastContract, ApiCompatBaseline.Empty) @@ -35,6 +36,7 @@ public SourceInputModel(Compilation? customization, Compilation? lastContract, A ApiCompatBaseline = apiCompatBaseline ?? ApiCompatBaseline.Empty; _nameMap = new(PopulateNameMap); + _customizationTypeProviders = new(PopulateCustomizationTypeProviders); } private IReadOnlyDictionary PopulateNameMap() @@ -70,6 +72,30 @@ private IReadOnlyDictionary PopulateNameMap() return FindTypeInCompilation(LastContract, ns, name, true, declaringTypeName, includeInternal: false); } + private IReadOnlyList PopulateCustomizationTypeProviders() + { + var providers = new List(); + if (Customization == null) + { + return providers; + } + + foreach (IModuleSymbol module in Customization.Assembly.Modules) + { + foreach (var type in SourceInputHelper.GetSymbols(module.GlobalNamespace)) + { + if (type is INamedTypeSymbol namedTypeSymbol) + { + providers.Add(new NamedTypeSymbolProvider(namedTypeSymbol, Customization)); + } + } + } + + return providers; + } + + internal IReadOnlyList GetCustomizationTypeProviders() => _customizationTypeProviders.Value; + private TypeProvider? FindTypeInCompilation( Compilation? compilation, string ns, diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs index cfca3c2db7b..110f92af48e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Statements/XmlDocStatement.cs @@ -217,6 +217,8 @@ public static string EscapeLine(string s) "", "", "", + "", + "", "", "", " _optionalProvider ??= new(); + private Dictionary InputTypeToModelProvider { get; } = []; public IDictionary CSharpTypeMap { get; } = new Dictionary(CSharpType.IgnoreNullableComparer); @@ -200,11 +203,6 @@ protected internal TypeFactory() if (modelProvider != null) { - if (model.Access == "public") - { - CodeModelGenerator.Instance.AddTypeToKeep(modelProvider); - } - CSharpTypeMap[modelProvider.Type] = modelProvider; TypeProvidersByName[modelProvider.Type.Name] = modelProvider; } @@ -299,11 +297,6 @@ protected virtual ModelFactoryProvider CreateModelFactoryCore(IEnumerable enumProvider, }; - if (enumType.Access == "public") - { - CodeModelGenerator.Instance.AddTypeToKeep(enumProvider); - } - EnumCache.Add(enumCacheKey, enumProvider); if (enumProvider != null) @@ -500,6 +493,11 @@ inputProperty.Type is InputArrayType && /// public virtual CSharpType DictionaryInitializationType => ChangeTrackingDictionaryProvider.Type; + /// + /// The type used to represent optional values in generated helper code. + /// + public virtual CSharpType OptionalType => OptionalProvider.Type; + /// /// Returns the serialization type providers for the given model type provider. /// diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs index 8af371a063c..c7915847be6 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Utilities/TypeSymbolExtensions.cs @@ -43,6 +43,9 @@ public static bool IsSameType(this INamedTypeSymbol symbol, CSharpType type) } public static CSharpType GetCSharpType(this ITypeSymbol typeSymbol) + => GetCSharpType(typeSymbol, new HashSet(SymbolEqualityComparer.Default)); + + private static CSharpType GetCSharpType(this ITypeSymbol typeSymbol, HashSet visited) { var fullyQualifiedName = GetFullyQualifiedName(typeSymbol); var namedTypeSymbol = typeSymbol as INamedTypeSymbol; @@ -55,20 +58,20 @@ public static CSharpType GetCSharpType(this ITypeSymbol typeSymbol) if (namedTypeSymbol?.ConstructedFrom.SpecialType == SpecialType.System_Nullable_T && namedTypeSymbol.TypeArguments.Length == 1) { - var underlying = GetCSharpType(namedTypeSymbol.TypeArguments[0]); + var underlying = GetCSharpType(namedTypeSymbol.TypeArguments[0], visited); if (underlying.IsFrameworkType) { return underlying.WithNullable(true); } } - return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol); + return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol, visited); } CSharpType result = new CSharpType(type); if (namedTypeSymbol is not null && namedTypeSymbol.IsGenericType && !result.IsNullable) { - return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)]); + return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(t => GetCSharpType(t, visited))]); } return result; @@ -175,8 +178,14 @@ public static string GetFullyQualifiedNameFromDisplayString(this ISymbol typeSym private static CSharpType ConstructCSharpTypeFromSymbol( ITypeSymbol typeSymbol, string fullyQualifiedName, - INamedTypeSymbol? namedTypeSymbol) + INamedTypeSymbol? namedTypeSymbol, + HashSet visited) { + if (!visited.Add(typeSymbol)) + { + return ConstructShallowCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName); + } + var typeArg = namedTypeSymbol?.TypeArguments.FirstOrDefault(); bool isValueType = typeSymbol.IsValueType; bool isNullable = fullyQualifiedName.StartsWith(NullableTypeName); @@ -193,7 +202,7 @@ private static CSharpType ConstructCSharpTypeFromSymbol( if (namedTypeSymbol?.IsGenericType == true && (!isNullable || (namedTypeArg?.IsGenericType == true))) { - arguments.AddRange(namedTypeSymbol.TypeArguments.Select(GetCSharpType)); + arguments.AddRange(namedTypeSymbol.TypeArguments.Select(t => GetCSharpType(t, visited))); } // handle nullables @@ -207,9 +216,9 @@ private static CSharpType ConstructCSharpTypeFromSymbol( string ns = string.Join('.', pieces.Take(pieces.Length - 1)); CSharpType? containingType = null; - if (typeSymbol.ContainingType != null) + if (typeSymbol.ContainingType != null && typeSymbol.TypeKind != TypeKind.TypeParameter) { - containingType = GetCSharpType(typeSymbol.ContainingType); + containingType = GetCSharpType(typeSymbol.ContainingType, visited); ns = string.Join('.', pieces.Take(pieces.Length - 2)); } @@ -219,10 +228,10 @@ private static CSharpType ConstructCSharpTypeFromSymbol( !isNullableUnknownType && !ContainsTypeAsArgument(typeSymbol.BaseType, typeSymbol)) { - baseType = GetCSharpType(typeSymbol.BaseType); + baseType = GetCSharpType(typeSymbol.BaseType, visited); } - return new CSharpType( + var result = new CSharpType( name, ns, isValueType, @@ -233,8 +242,24 @@ private static CSharpType ConstructCSharpTypeFromSymbol( isValueType && !isEnum, baseType: baseType, underlyingEnumType: enumUnderlyingType != null - ? GetCSharpType(enumUnderlyingType).FrameworkType + ? GetCSharpType(enumUnderlyingType, visited).FrameworkType : null); + visited.Remove(typeSymbol); + return result; + } + + private static CSharpType ConstructShallowCSharpTypeFromSymbol(ITypeSymbol typeSymbol, string fullyQualifiedName) + { + string[] pieces = fullyQualifiedName.Split('`')[0].Split('.'); + return new CSharpType( + typeSymbol.Name, + string.Join('.', pieces.Take(pieces.Length - 1)), + typeSymbol.IsValueType, + fullyQualifiedName.StartsWith(NullableTypeName), + null, + [], + typeSymbol.DeclaredAccessibility == Accessibility.Public, + typeSymbol.IsValueType && typeSymbol.TypeKind != TypeKind.Enum); } internal static bool ContainsTypeAsArgument(ITypeSymbol potentialGenericType, ITypeSymbol targetType) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs index f716aad5ab3..240dd77a792 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/OutputLibraryVisitorTests.cs @@ -451,7 +451,7 @@ private class TestFilterVisitor : LibraryVisitor return method; } - protected internal override ConstructorProvider? VisitConstructor(ConstructorProvider constructor) + protected override ConstructorProvider? VisitConstructor(ConstructorProvider constructor) { if (constructor.Signature.Parameters.Count > 0) { @@ -469,7 +469,7 @@ private class TestFilterVisitor : LibraryVisitor return property; } - protected internal override FieldProvider? VisitField(FieldProvider field) + protected override FieldProvider? VisitField(FieldProvider field) { if (field.Name == "TestField") { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs index 28981148a4d..8e8c10dca8d 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/PostProcessorTests.cs @@ -60,7 +60,42 @@ public async Task RemovesInvalidUsings() CollectionAssert.Contains(usings, "System"); } + [Test] + public async Task RemovesInvalidUsingsKeepsFileHeader() + { + MockHelpers.LoadMockGenerator(); + var workspace = new AdhocWorkspace(); + var projectInfo = ProjectInfo.Create( + ProjectId.CreateNewId(), + VersionStamp.Create(), + name: "TestProj", + assemblyName: "TestProj", + language: LanguageNames.CSharp) + .WithMetadataReferences(new[] + { + MetadataReference.CreateFromFile(typeof(object).Assembly.Location) + }); + + var project = workspace.AddProject(projectInfo); + var folder = Helpers.GetAssetFileOrDirectoryPath(false); + project = AddGeneratedDocument( + project, + "RootClient.cs", + "src", + "Generated", + "RootClient.cs", + File.ReadAllText(Path.Join(folder, "RootClient.cs"))); + var postProcessor = new TestPostProcessor("RootClient.cs"); + + var resultProject = await postProcessor.RemoveAsync(project); + var doc = resultProject.Documents.Single(d => d.Name == "RootClient.cs"); + var text = (await doc.GetTextAsync()).ToString(); + StringAssert.StartsWith("// Copyright (c) Microsoft Corporation. All rights reserved.", text); + StringAssert.Contains("// ", text); + StringAssert.Contains("#nullable disable", text); + StringAssert.DoesNotContain("using Missing.Namespace;", text); + } [Test] public async Task DoesNotRemoveValidUsings() { @@ -289,11 +324,14 @@ public async Task DoesNotRemoveValidAttributes() Assert.AreEqual(Helpers.GetExpectedFromFile().TrimEnd(), output, "The output should match the expected content."); } + private static Project AddGeneratedDocument(Project project, string name, string folder1, string folder2, string fileName, string text) + => project.AddDocument(name, text, folders: [folder1, folder2], filePath: Path.Join(folder1, folder2, fileName)).Project; + private class TestPostProcessor : PostProcessor { private readonly string _rootFile; - public TestPostProcessor(string rootFile, IEnumerable? nonRootTypes = null) : base([], additionalNonRootTypeNames: nonRootTypes) + public TestPostProcessor(string rootFile, IEnumerable? additionalRootTypeNames = null, IEnumerable? nonRootTypes = null, string? modelFactoryFullName = null) : base((additionalRootTypeNames ?? []).ToHashSet(), modelFactoryFullName: modelFactoryFullName, additionalNonRootTypeNames: nonRootTypes) { _rootFile = rootFile; } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs new file mode 100644 index 00000000000..572d6f590ac --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/PostProcessing/TestData/PostProcessorTests/RemovesInvalidUsingsKeepsFileHeader/RootClient.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// + +#nullable disable + +using Missing.Namespace; + +namespace Sample +{ + public partial class RootClient + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/EnumProviders/EnumProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/EnumProviders/EnumProviderTests.cs index 3bbc8cb463e..e112cec28b0 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/EnumProviders/EnumProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/EnumProviders/EnumProviderTests.cs @@ -331,7 +331,7 @@ public void ExtensibleFloatEnum_HasOnlyNonNullableImplicitOperator() } [Test] - public void PublicModelsAreIncludedInAdditionalRootTypes() + public void PublicEnumsAreNotIncludedInAdditionalRootTypes() { var inputEnum = InputFactory.StringEnum( "StringEnum", @@ -345,7 +345,7 @@ public void PublicModelsAreIncludedInAdditionalRootTypes() Assert.IsNotNull(enumProvider); var rootTypes = CodeModelGenerator.Instance.AdditionalRootTypes; - Assert.IsTrue(rootTypes.Contains("Sample.Models.StringEnum")); + Assert.IsFalse(rootTypes.Contains("Sample.Models.StringEnum")); } [Test] diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs index 1bdf4020167..c35f1968d76 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -177,6 +177,32 @@ public async Task OmitsModelFactoryMethodIfParamTypeInternal() Assert.IsNull(modelFactory); } + // This test validates that a derived model customized to be internal does not get a + // public model factory method just because its base model remains public. + [Test] + public async Task OmitsModelFactoryMethodIfDerivedModelTypeInternal() + { + var baseModel = InputFactory.Model( + "baseModel", + properties: [InputFactory.Property("BaseProp", InputPrimitiveType.String)]); + var derivedModel = InputFactory.Model( + "derivedModel", + properties: [InputFactory.Property("DerivedProp", InputPrimitiveType.String)], + baseModel: baseModel); + + var mockGenerator = await MockHelpers.LoadMockGeneratorAsync( + inputModelTypes: [baseModel, derivedModel], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + var modelFactory = mockGenerator.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + CollectionAssert.Contains(modelFactory!.Methods.Select(m => m.Signature.Name), "BaseModel"); + CollectionAssert.DoesNotContain(modelFactory.Methods.Select(m => m.Signature.Name), "DerivedModel"); + } + [TestCase(true)] [TestCase(false)] public async Task CanCustomizeModelFullConstructor(bool extraParameters) diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs index 934239ba78f..45f9b0385b4 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/ModelFactoryProviderTests.cs @@ -975,61 +975,6 @@ public async Task BackCompatibility_BackCompatMethodAlreadyCustom() Assert.AreEqual(4, publicModel1Methods[0].Signature.Parameters.Count); } - // Back-compat members are synthesized in ProcessTypeForBackCompatibility, which runs after the - // main library visitor pass. This test ensures those newly-added members are still run through - // the registered visitors (only the new members, not the already-visited existing ones). - [Test] - public async Task BackCompatibility_BackCompatMethodIsVisited() - { - _instance = (await MockHelpers.LoadMockGeneratorAsync( - inputNamespaceName: "Sample.Namespace", - inputModelTypes: ModelList, - lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(method: "BackCompatibility_NewModelPropertyAdded"))).Object; - - var recordingVisitor = new RecordingMethodVisitor(); - _instance.AddVisitor(recordingVisitor); - - var modelFactory = _instance!.OutputLibrary.ModelFactory.Value; - modelFactory.ProcessTypeForBackCompatibility(); - - var backCompatMethod = modelFactory.Methods - .FirstOrDefault(m => m.Signature.Name == "PublicModel1" && m.Signature.Parameters.All(p => p.Name != "dictProp")); - Assert.IsNotNull(backCompatMethod, "Expected a back-compat overload to be synthesized."); - - // The synthesized back-compat method must have been visited. - Assert.IsTrue( - recordingVisitor.VisitedMethods.Contains(backCompatMethod!), - "The back-compat method was not visited by the library visitor."); - - // Existing methods that were already part of the contract are not re-visited by the visitor - // added after the main pass (they would have been visited during the main pass in a real run). - var currentOverloadMethod = modelFactory.Methods - .FirstOrDefault(m => m.Signature.Name == "PublicModel1" && m.Signature.Parameters.Any(p => p.Name == "dictProp")); - Assert.IsNotNull(currentOverloadMethod); - Assert.IsFalse(recordingVisitor.VisitedMethods.Contains(currentOverloadMethod!)); - } - - // Verifies that a visitor can mutate (rename) a synthesized back-compat method and the change is - // reflected in the final generated methods. - [Test] - public async Task BackCompatibility_BackCompatMethodCanBeMutatedByVisitor() - { - _instance = (await MockHelpers.LoadMockGeneratorAsync( - inputNamespaceName: "Sample.Namespace", - inputModelTypes: ModelList, - lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(method: "BackCompatibility_NewModelPropertyAdded"))).Object; - - _instance.AddVisitor(new BackCompatMethodRenamingVisitor()); - - var modelFactory = _instance!.OutputLibrary.ModelFactory.Value; - modelFactory.ProcessTypeForBackCompatibility(); - - // The visitor renames any method carrying the EditorBrowsableNever attribute (the back-compat - // overload) so the mutation must be observable on the final method collection. - var renamed = modelFactory.Methods.FirstOrDefault(m => m.Signature.Name == "PublicModel1Renamed"); - Assert.IsNotNull(renamed, "The visitor's rename of the back-compat method was not applied."); - } - private static InputModelType[] GetTestModels() { InputType additionalPropertiesUnknown = InputPrimitiveType.Any; @@ -1057,29 +1002,5 @@ private static InputModelType[] GetTestModels() InputFactory.Model("ModelWithUnknownAdditionalProperties", properties: properties, additionalProperties: additionalPropertiesUnknown), ]; } - - private class RecordingMethodVisitor : LibraryVisitor - { - public List VisitedMethods { get; } = []; - - protected internal override MethodProvider? VisitMethod(MethodProvider method) - { - VisitedMethods.Add(method); - return method; - } - } - - private class BackCompatMethodRenamingVisitor : LibraryVisitor - { - protected internal override MethodProvider? VisitMethod(MethodProvider method) - { - if (method.Signature.Name == "PublicModel1" - && method.Signature.Attributes.Any(a => a.ToDisplayString().Contains("EditorBrowsable"))) - { - method.Signature.Update(name: "PublicModel1Renamed"); - } - return method; - } - } } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs new file mode 100644 index 00000000000..bdb2034f5f0 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/OmitsModelFactoryMethodIfDerivedModelTypeInternal/DerivedModel.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Sample.Models +{ + internal partial class DerivedModel + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs index 7f3ea5cd1ff..d96e98a9d8e 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ClientCustomizationTests.cs @@ -405,7 +405,7 @@ private class ClientTypeProvider : TypeProvider public MethodProvider[] MethodProviders { get; set; } = []; public ConstructorProvider[] ConstructorProviders { get; set; } = []; - protected override string BuildRelativeFilePath() => "."; + protected override string BuildRelativeFilePath() => $"{Name}.cs"; protected override string BuildName() => "MockInputClient"; diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs index 9f9945a2360..a30bd15aa29 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs @@ -1524,7 +1524,7 @@ await MockHelpers.LoadMockGeneratorAsync( } [Test] - public void PublicModelsAreIncludedInAdditionalRootTypes() + public void PublicModelsAreNotIncludedInAdditionalRootTypes() { var inputModel = InputFactory.Model( "MockInputModel", @@ -1537,7 +1537,7 @@ public void PublicModelsAreIncludedInAdditionalRootTypes() Assert.IsNotNull(modelProvider); var rootTypes = CodeModelGenerator.Instance.AdditionalRootTypes; - Assert.IsTrue(rootTypes.Contains("Sample.Models.MockInputModel")); + Assert.IsFalse(rootTypes.Contains("Sample.Models.MockInputModel")); } [Test] diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/TestData/PropertyDescriptionTests/TestGetUnionTypesDescriptions.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/TestData/PropertyDescriptionTests/TestGetUnionTypesDescriptions.cs index 27c2aea8767..130f92a90fa 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/TestData/PropertyDescriptionTests/TestGetUnionTypesDescriptions.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/TestData/PropertyDescriptionTests/TestGetUnionTypesDescriptions.cs @@ -1,7 +1,7 @@ /// -/// . -/// . -/// where TKey is of type , where TValue is of type . +/// bool. +/// int. +/// global::System.Collections.Generic.IDictionary{string,int}. /// 21. /// "test". /// True. diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs new file mode 100644 index 00000000000..be99b9ae2f2 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/ReferenceMap/ProviderReferenceMapAnalyzerTests.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.TypeSpec.Generator.Primitives; +using Microsoft.TypeSpec.Generator.Providers; +using Microsoft.TypeSpec.Generator.Tests.TestHelpers; +using NUnit.Framework; + +namespace Microsoft.TypeSpec.Generator.Tests.ReferenceMap +{ + public class ProviderReferenceMapAnalyzerTests + { + [SetUp] + public void SetUp() + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + } + + [TearDown] + public void TearDown() + { + ProviderReferenceMapAnalyzer.ResetPreWriteAccessibility(); + } + + [Test] + public void NonRootKeptTypesKeepTheirAccessibility() + { + var context = new TestTypeProvider("SampleContext", TypeSignatureModifiers.Public); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary(context)); + CodeModelGenerator.Instance.AddTypeToKeep(context.Type.FullyQualifiedName, isRoot: false); + + ProviderReferenceMapAnalyzer.ApplyPreWriteAccessibility([context]); + + Assert.IsTrue(context.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + Assert.IsFalse(context.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)); + } + + [Test] + public void NonRootKeptTypesAreWrittenWithoutRootingOtherTypes() + { + var context = new TestTypeProvider("SampleContext", TypeSignatureModifiers.Public); + var unusedModel = new TestTypeProvider("UnusedModel", TypeSignatureModifiers.Public); + MockHelpers.LoadMockGenerator(createOutputLibrary: () => new TestOutputLibrary(context, unusedModel)); + CodeModelGenerator.Instance.AddTypeToKeep(context.Type.FullyQualifiedName, isRoot: false); + + ProviderReferenceMapAnalyzer.Analyze([context, unusedModel]); + + Assert.IsTrue(ProviderReferenceMapAnalyzer.ShouldWriteProvider(context)); + Assert.IsFalse(ProviderReferenceMapAnalyzer.ShouldWriteProvider(unusedModel)); + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs index 00119db05dd..7490c2f41e5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/TestHelpers/TestOutputLibrary.cs @@ -14,6 +14,11 @@ public TestOutputLibrary(TypeProvider typeProvider) _types = [typeProvider]; } + public TestOutputLibrary(params TypeProvider[] typeProviders) + { + _types = typeProviders; + } + protected override TypeProvider[] BuildTypeProviders() => _types; } } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs new file mode 100644 index 00000000000..27bd4cf11cd --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TestData/TypeSymbolExtensionsTests/TypeParameterDoesNotResolveContainingGenericType/GenericContainer.cs @@ -0,0 +1,6 @@ +namespace Sample +{ + public class GenericContainer + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs index 3884f8d584e..e8db4359802 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Utilities/TypeSymbolExtensionsTests.cs @@ -83,6 +83,19 @@ public async Task NonNullableKnownFrameworkTypeResolvesUnchanged() Assert.IsFalse(csharpType.IsNullable); } + [Test] + public async Task TypeParameterDoesNotResolveContainingGenericType() + { + var compilation = await Helpers.GetCompilationFromDirectoryAsync(); + var typeSymbol = compilation.GetTypeByMetadataName("Sample.GenericContainer`1"); + Assert.IsNotNull(typeSymbol, "Failed to resolve generic type symbol from compiled source."); + + var csharpType = typeSymbol!.TypeParameters[0].GetCSharpType(); + + Assert.AreEqual("T", csharpType.Name); + Assert.IsNull(csharpType.DeclaringType); + } + private static IPropertySymbol GetPropertySymbol(Compilation compilation, string containerName, string propertyName) { var typeSymbol = compilation.GetTypeByMetadataName($"Sample.{containerName}"); diff --git a/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs b/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs index 96618c816b8..77af28e9169 100644 --- a/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs +++ b/packages/http-client-csharp/generator/TestProjects/Local/Sample-TypeSpec/src/Generated/Models/Thing.cs @@ -100,13 +100,13 @@ internal Thing(string rename, BinaryData requiredUnion, string requiredLiteralSt /// Supported types: /// /// - /// . + /// string. /// /// - /// where T is of type . + /// global::System.Collections.Generic.IList{string}. /// /// - /// . + /// int. /// /// /// diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs index 8918ac1946c..4440dc1b6a0 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/documentation/src/Generated/DocumentationModelFactory.cs @@ -8,7 +8,6 @@ namespace Documentation { public static partial class DocumentationModelFactory { - public static BulletPointsModel BulletPointsModel(BulletPointsEnum prop = default) => throw null; } } diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs index 06d44d34bc1..c17dda5eaec 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/basic/src/Generated/ParametersBasicModelFactory.cs @@ -3,7 +3,6 @@ #nullable disable using Parameters.Basic._ExplicitBody; -using Parameters.Basic._ImplicitBody; namespace Parameters.Basic { diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs index 775c933bc6b..4e494c2f2ec 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/parameters/spread/src/Generated/ParametersSpreadModelFactory.cs @@ -2,8 +2,6 @@ #nullable disable -using System.Collections.Generic; -using Parameters.Spread._Alias; using Parameters.Spread._Model; namespace Parameters.Spread diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs index 6d036b01c86..e5c87b994a6 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/multipart/src/Generated/PayloadMultiPartModelFactory.cs @@ -2,7 +2,6 @@ #nullable disable -using System; using System.ClientModel; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs index 0b66e798a05..1b7b8eefa6f 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/payload/pageable/src/Generated/PayloadPageableModelFactory.cs @@ -2,18 +2,12 @@ #nullable disable -using System; -using System.Collections.Generic; -using Payload.Pageable._PageSize; -using Payload.Pageable._ServerDrivenPagination; using Payload.Pageable._ServerDrivenPagination.AlternateInitialVerb; -using Payload.Pageable._ServerDrivenPagination.ContinuationToken; namespace Payload.Pageable { public static partial class PayloadPageableModelFactory { - public static Pet Pet(string id = default, string name = default) => throw null; public static XmlPet XmlPet(string id = default, string name = default) => throw null; diff --git a/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs b/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs index 1abb2114c9a..43aca8db259 100644 --- a/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs +++ b/packages/http-client-csharp/generator/TestProjects/Spector/http/special-words/src/Generated/SpecialWordsModelFactory.cs @@ -2,10 +2,8 @@ #nullable disable -using System.Collections.Generic; using SpecialWords._ModelProperties; using SpecialWords._Models; -using SpecialWords._ReservedOperationBodyParams; namespace SpecialWords {