3333#include < libyul/AST.h>
3434
3535#include < boost/algorithm/string.hpp>
36+
3637#include < range/v3/view/transform.hpp>
3738
3839using namespace solidity ;
@@ -166,6 +167,13 @@ bool TypeInference::visit(FunctionDefinition const& _functionDefinition)
166167 return false ;
167168}
168169
170+ void TypeInference::endVisit (FunctionDefinition const & _functionDefinition)
171+ {
172+ solAssert (m_expressionContext == ExpressionContext::Term);
173+
174+ m_env->fixTypeVars (TypeEnvironmentHelpers{*m_env}.typeVars (type (_functionDefinition)));
175+ }
176+
169177void TypeInference::endVisit (Return const & _return)
170178{
171179 solAssert (m_currentFunctionType);
@@ -204,6 +212,8 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
204212 solAssert (m_analysis.annotation <TypeClassRegistration>(_typeClassDefinition).typeClass .has_value ());
205213 TypeClass typeClass = m_analysis.annotation <TypeClassRegistration>(_typeClassDefinition).typeClass .value ();
206214 Type typeVar = m_typeSystem.typeClassVariable (typeClass);
215+ unify (type (_typeClassDefinition.typeVariable ()), typeVar, _typeClassDefinition.location ());
216+
207217 auto & typeMembersAnnotation = annotation ().members [typeConstructor (&_typeClassDefinition)];
208218
209219 for (auto subNode: _typeClassDefinition.subNodes ())
@@ -235,7 +245,6 @@ bool TypeInference::visit(TypeClassDefinition const& _typeClassDefinition)
235245 m_errorReporter.typeError (1807_error, _typeClassDefinition.location (), " Function " + functionName + " depends on invalid type variable." );
236246 }
237247
238- unify (type (_typeClassDefinition.typeVariable ()), m_typeSystem.freshTypeVariable ({{typeClass}}), _typeClassDefinition.location ());
239248 for (auto instantiation: m_analysis.annotation <TypeRegistration>(_typeClassDefinition).instantiations | ranges::views::values)
240249 // TODO: recursion-safety? Order of instantiation?
241250 instantiation->accept (*this );
@@ -663,6 +672,7 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
663672 }) | ranges::to<std::vector<Sort>>;
664673 }
665674 }
675+ m_env->fixTypeVars (arguments);
666676
667677 Type instanceType{TypeConstant{*typeConstructor, arguments}};
668678
@@ -682,12 +692,8 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
682692
683693 auto const & classFunctions = annotation ().typeClassFunctions .at (typeClass);
684694
685- TypeEnvironment newEnv = m_env->clone ();
686- if (!newEnv.unify (m_typeSystem.typeClassVariable (typeClass), instanceType).empty ())
687- {
688- m_errorReporter.typeError (4686_error, _typeClassInstantiation.location (), " Unification of class and instance variable failed." );
689- return false ;
690- }
695+ solAssert (std::holds_alternative<TypeVariable>(m_typeSystem.typeClassVariable (typeClass)));
696+ TypeVariable classVar = std::get<TypeVariable>(m_typeSystem.typeClassVariable (typeClass));
691697
692698 for (auto [name, classFunctionType]: classFunctions)
693699 {
@@ -696,11 +702,22 @@ bool TypeInference::visit(TypeClassInstantiation const& _typeClassInstantiation)
696702 m_errorReporter.typeError (6948_error, _typeClassInstantiation.location (), " Missing function: " + name);
697703 continue ;
698704 }
705+ Type instantiatedClassFunctionType = TypeEnvironmentHelpers{*m_env}.substitute (classFunctionType, classVar, instanceType);
706+
699707 Type instanceFunctionType = functionTypes.at (name);
700708 functionTypes.erase (name);
701709
702- if (!newEnv.typeEquals (instanceFunctionType, classFunctionType))
703- m_errorReporter.typeError (7428_error, _typeClassInstantiation.location (), " Type mismatch for function " + name + " " + TypeEnvironmentHelpers{newEnv}.typeToString (instanceFunctionType) + " != " + TypeEnvironmentHelpers{newEnv}.typeToString (classFunctionType));
710+ if (!m_env->typeEquals (instanceFunctionType, instantiatedClassFunctionType))
711+ m_errorReporter.typeError (
712+ 7428_error,
713+ _typeClassInstantiation.location (),
714+ fmt::format (
715+ " Instantiation function '{}' does not match the declaration in the type class ({} != {})." ,
716+ name,
717+ TypeEnvironmentHelpers{*m_env}.typeToString (instanceFunctionType),
718+ TypeEnvironmentHelpers{*m_env}.typeToString (instantiatedClassFunctionType)
719+ )
720+ );
704721 }
705722
706723 if (!functionTypes.empty ())
@@ -760,12 +777,21 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
760777 return false ;
761778
762779 if (_typeDefinition.arguments ())
780+ {
781+ ScopedSaveAndRestore expressionContext{m_expressionContext, ExpressionContext::Type};
763782 _typeDefinition.arguments ()->accept (*this );
783+ }
764784
765785 std::vector<Type> arguments;
766786 if (_typeDefinition.arguments ())
767- for (size_t i = 0 ; i < _typeDefinition.arguments ()->parameters ().size (); ++i)
768- arguments.emplace_back (m_typeSystem.freshTypeVariable ({}));
787+ for (ASTPointer<VariableDeclaration> argumentDeclaration: _typeDefinition.arguments ()->parameters ())
788+ {
789+ solAssert (argumentDeclaration);
790+ Type typeVar = type (*argumentDeclaration);
791+ solAssert (std::holds_alternative<TypeVariable>(typeVar));
792+ arguments.emplace_back (typeVar);
793+ }
794+ m_env->fixTypeVars (arguments);
769795
770796 Type definedType = type (&_typeDefinition, arguments);
771797 if (arguments.empty ())
@@ -791,6 +817,9 @@ bool TypeInference::visit(TypeDefinition const& _typeDefinition)
791817 solAssert (newlyInserted, fmt::format (" Members of type '{}' are already defined." , m_typeSystem.constructorInfo (constructor).name ));
792818 if (underlyingType)
793819 {
820+ // Undeclared type variables are not allowed in type definitions and we fixed all the declared ones.
821+ solAssert (!TypeEnvironmentHelpers{*m_env}.hasGenericTypeVars (*underlyingType));
822+
794823 members->second .emplace (" abs" , TypeMember{helper.functionType (*underlyingType, definedType)});
795824 members->second .emplace (" rep" , TypeMember{helper.functionType (definedType, *underlyingType)});
796825 }
0 commit comments