In one of my controller+action pairs, I get the names of another controller and action as strings from somewhere, and I want to redirect my current action to them. Before redirecting, I want to make sure that the controller+action pair exists in my app; if it doesn't, I want to redirect to a 404 page. I am looking for a way to do this.

public ActionResult MyTestAction()
{
    string controller = getFromSomewhere();
    string action = getFromSomewhereToo();

    /*
      At this point use reflection and make sure action and controller exists
      else redirect to error 404
    */ 

    return RedirectToRoute(new { action = action, controller = controller });
}

This is what I have tried so far, but it doesn't work:

var cont = Assembly.GetExecutingAssembly().GetType(controller);
if (cont != null && cont.GetMethod(action) != null)
{ 
    // controller and action pair is valid
}
else
{ 
    // controller and action pair is invalid
}

4 Answers

Answer 1

You can implement IRouteConstraint and use it in your route table.

The implementation of this route constraint can then use reflection to check whether the controller/action exists. If it doesn't, the route is skipped. As the last route in your route table, you can add a catch-all route and map it to an action that renders a 404 view.

Here's a code snippet to get you started:

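using System;
using System.Linq;
using System.Reflection;
using System.Web;
using System.Web.Routing;

public class MyRouteConstraint : IRouteConstraint
{
    public bool Match(HttpContextBase httpContext, Route route, string parameterName, RouteValueDictionary values, RouteDirection routeDirection)
    {
        var action = values["action"] as string;
        var controller = values["controller"] as string;
        if (string.IsNullOrEmpty(action) || string.IsNullOrEmpty(controller))
        {
            return false;
        }

        // GetType() needs the namespace-qualified type name; replace MvcApplication1 with your project's namespace.
        var controllerFullName = string.Format("MvcApplication1.Controllers.{0}Controller", controller);
        // ignoreCase: true, so that a URL like /home/index still finds HomeController.
        var cont = Assembly.GetExecutingAssembly().GetType(controllerFullName, false, true);

        // GetMethods() + Any() avoids the AmbiguousMatchException that GetMethod() throws
        // when an action is overloaded (e.g. separate [HttpGet] and [HttpPost] versions).
        return cont != null && cont.GetMethods().Any(m => string.Equals(m.Name, action, StringComparison.OrdinalIgnoreCase));
    }
}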

Note that GetType() needs the fully-qualified name of the controller, including its namespace.

RouteConfig.cs

routes.MapRoute(
    "Home", // Route name
    "{controller}/{action}", // URL with parameters
    new { controller = "Home", action = "Index" }, // Parameter defaults
    new { action = new MyRouteConstraint() } // Route constraints
);

routes.MapRoute(
    "PageNotFound", // Route name
    "{*catchall}", // URL with parameters
    new { controller = "Home", action = "PageNotFound" } // Parameter defaults
);
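The attempt in the question returns null because Assembly.GetType() only matches namespace-qualified names: GetType("HomeController") finds nothing even when the class exists. Below is a minimal, self-contained sketch of the fully-qualified lookup that runs without an MVC reference. HomeController, the Demo.Controllers namespace and the ControllerLookup helper are hypothetical names for illustration; substitute your own controller namespace. Be aware that GetMethods() also returns inherited public methods such as ToString, so this check is looser than MVC's real action selection.

```csharp
using System;
using System.Linq;
using System.Reflection;

namespace Demo.Controllers
{
    // Hypothetical stand-in for an MVC controller (no System.Web.Mvc reference needed).
    public class HomeController
    {
        public void Index() { }
    }
}

public static class ControllerLookup
{
    // Builds the namespace-qualified name Assembly.GetType() expects.
    // "Demo.Controllers" is an assumption; use your project's controller namespace.
    public static Type Find(string controller)
    {
        string fullName = "Demo.Controllers." + controller + "Controller";
        // throwOnError: false returns null instead of throwing; ignoreCase: true.
        return Assembly.GetExecutingAssembly().GetType(fullName, false, true);
    }

    // True when the controller type exists and has a public method named like the action.
    public static bool Exists(string controller, string action)
    {
        Type type = Find(controller);
        return type != null
            && type.GetMethods().Any(m => string.Equals(m.Name, action, StringComparison.OrdinalIgnoreCase));
    }
}
```

For example, ControllerLookup.Exists("home", "index") is true here, while Exists("Home", "Missing") and Exists("Missing", "Index") are false.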

Answer 2

If you can't obtain the fully-qualified name of the controller to pass to GetType(), you'll need to use GetTypes() and then do a string comparison over the results.

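Type[] types = System.Reflection.Assembly.GetExecutingAssembly().GetTypes();

// Route values hold "Home", but the class is named "HomeController" (requires System.Linq).
Type type = types.SingleOrDefault(t => t.Name == controller + "Controller");

if (type != null && type.GetMethod(action) != null)
{
    // controller and action pair is valid
}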

Comment:

If the controller has two matches for the action (for example an HttpGet and an HttpPost overload), GetMethod() throws an AmbiguousMatchException here. It would be safer to change the last line to: if (type != null && type.GetMethods().Any(p => p.Name == action))
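To illustrate the comment above, here is a small self-contained sketch; the ProductsController class and the OverloadCheck helper are hypothetical. It shows that GetMethod(name) throws AmbiguousMatchException when two public methods share a name, as a [HttpGet]/[HttpPost] pair usually does, while GetMethods().Any(...) just reports whether at least one match exists.

```csharp
using System;
using System.Linq;
using System.Reflection;

// Hypothetical controller with two overloads of the same action name,
// mimicking a [HttpGet] Edit() and a [HttpPost] Edit(int id) pair.
public class ProductsController
{
    public string Edit() { return "get"; }
    public string Edit(int id) { return "post " + id; }
}

public static class OverloadCheck
{
    // Returns true if Type.GetMethod(name) throws because the name is overloaded.
    public static bool ThrowsWithGetMethod(Type type, string action)
    {
        try
        {
            type.GetMethod(action); // returns null (no throw) when nothing matches
            return false;
        }
        catch (AmbiguousMatchException)
        {
            return true;
        }
    }

    // Overload-safe existence check: does any public method have this name?
    public static bool ActionExists(Type type, string action)
    {
        return type.GetMethods().Any(m => m.Name == action);
    }
}
```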
Answer 3

We solved this by adding this line to our WebApiConfig.cs file:

config.Services.Replace(typeof(IHttpControllerSelector), new AcceptHeaderControllerSelector(config));

The core method I used is shown below. It lives in the AcceptHeaderControllerSelector class, which implements the IHttpControllerSelector interface.

We did it this way because we have to version our API: it let us create a new controller (e.g. V2) containing only the methods we were versioning, and fall back to V1 when a V2 version didn't exist.

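private HttpControllerDescriptor TryGetControllerWithMatchingMethod(string version, string controllerName, string actionName)
{
    // GetVersion() returns null when the header is missing, and the route may carry no action.
    if (string.IsNullOrWhiteSpace(version) || string.IsNullOrEmpty(actionName))
    {
        return null;
    }

    // The header value is expected to look like "V2"; anything unparsable is treated as "no match".
    int versionNumber;
    if (!int.TryParse(version.Substring(1), out versionNumber))
    {
        return null;
    }

    while (versionNumber >= 1)
    {
        var controllerFullName = string.Format("Namespace.Controller.V{0}.{1}Controller, Namespace.Controller.V{0}", versionNumber, controllerName);
        Type type = Type.GetType(controllerFullName, false, true);

        // GetMethods().Any() avoids the AmbiguousMatchException GetMethod() throws for overloaded actions.
        var matchFound = type != null && type.GetMethods().Any(m => string.Equals(m.Name, actionName, StringComparison.OrdinalIgnoreCase));

        if (matchFound)
        {
            var key = string.Format(CultureInfo.InvariantCulture, "V{0}{1}", versionNumber, controllerName);
            HttpControllerDescriptor controllerDescriptor;
            if (_controllers.TryGetValue(key, out controllerDescriptor))
            {
                return controllerDescriptor;
            }
        }
        versionNumber--;
    }

    return null;
}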

The full file can be seen below:

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Reflection;
using System.Web.Http;
using System.Web.Http.Controllers;
using System.Web.Http.Dispatcher;
using System.Web.Http.Routing;

namespace WebApi
{

    public class AcceptHeaderControllerSelector : IHttpControllerSelector
    {
        private const string ControllerKey = "controller";
        private const string ActionKey = "action";
        private const string VersionHeaderValueNotFoundExceptionMessage = "Version not found in headers";
        private const string VersionFoundInUrlAndHeaderErrorMessage = "Version can not be in Header and Url";
        private const string CouldNotFindEndPoint = "Could not find endpoint {0} for api version number {1}";
        private readonly HttpConfiguration _configuration;
        private readonly Dictionary<string, HttpControllerDescriptor> _controllers;

        public AcceptHeaderControllerSelector(HttpConfiguration config)
        {
            _configuration = config;
            _controllers = InitializeControllerDictionary();
        }

        private Dictionary<string, HttpControllerDescriptor> InitializeControllerDictionary()
        {
            var dictionary = new Dictionary<string, HttpControllerDescriptor>(StringComparer.OrdinalIgnoreCase);

            var assembliesResolver = _configuration.Services.GetAssembliesResolver();
            // This appears to scan every assembly referenced by the Web API project for controllers, so I had to add a reference to Controller.V2 for its controllers to be found
            var controllersResolver = _configuration.Services.GetHttpControllerTypeResolver();

            var controllerTypes = controllersResolver.GetControllerTypes(assembliesResolver);

            foreach (var t in controllerTypes)
            {
                var segments = t.Namespace.Split(Type.Delimiter);

                // For the dictionary key, strip "Controller" from the end of the type name.
                // This matches the behavior of DefaultHttpControllerSelector.
                var controllerName = t.Name.Remove(t.Name.Length - DefaultHttpControllerSelector.ControllerSuffix.Length);

                var key = string.Format(CultureInfo.InvariantCulture, "{0}{1}", segments[segments.Length - 1], controllerName);

                dictionary[key] = new HttpControllerDescriptor(_configuration, t.Name, t);
            }

            return dictionary;
        }

        public HttpControllerDescriptor SelectController(HttpRequestMessage request)
        {
            IHttpRouteData routeData = request.GetRouteData();

            if (routeData == null)  
            {
                throw new HttpResponseException(HttpStatusCode.NotFound);
            }

            var controllerName = GetRouteVariable<string>(routeData, ControllerKey);
            var actionName = GetRouteVariable<string>(routeData, ActionKey);

            if (controllerName == null)
            {
                throw new HttpResponseException(HttpStatusCode.NotFound);
            }

            var version = GetVersion(request);

            HttpControllerDescriptor controllerDescriptor;

            if (_controllers.TryGetValue(controllerName, out controllerDescriptor))
            {
                if (!string.IsNullOrWhiteSpace(version))
                {
                    throw new HttpResponseException(request.CreateResponse(HttpStatusCode.Forbidden, VersionFoundInUrlAndHeaderErrorMessage));
                }

                return controllerDescriptor;
            }

            controllerDescriptor = TryGetControllerWithMatchingMethod(version, controllerName, actionName);

            if (controllerDescriptor != null)
            {
                return controllerDescriptor;
            }

            var message = string.Format(CouldNotFindEndPoint, controllerName, version);

            throw new HttpResponseException(request.CreateResponse(HttpStatusCode.NotFound, message));
        }

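        private HttpControllerDescriptor TryGetControllerWithMatchingMethod(string version, string controllerName, string actionName)
        {
            // GetVersion() returns null when the header is missing, and the route may carry no action.
            if (string.IsNullOrWhiteSpace(version) || string.IsNullOrEmpty(actionName))
            {
                return null;
            }

            // The header value is expected to look like "V2"; anything unparsable is treated as "no match".
            int versionNumber;
            if (!int.TryParse(version.Substring(1), out versionNumber))
            {
                return null;
            }

            while (versionNumber >= 1)
            {
                var controllerFullName = string.Format("Namespace.Controller.V{0}.{1}Controller, Namespace.Controller.V{0}", versionNumber, controllerName);
                Type type = Type.GetType(controllerFullName, false, true);

                // GetMethods().Any() avoids the AmbiguousMatchException GetMethod() throws for overloaded actions.
                var matchFound = type != null && type.GetMethods().Any(m => string.Equals(m.Name, actionName, StringComparison.OrdinalIgnoreCase));

                if (matchFound)
                {
                    var key = string.Format(CultureInfo.InvariantCulture, "V{0}{1}", versionNumber, controllerName);
                    HttpControllerDescriptor controllerDescriptor;
                    if (_controllers.TryGetValue(key, out controllerDescriptor))
                    {
                        return controllerDescriptor;
                    }
                }
                versionNumber--;
            }

            return null;
        }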

        public IDictionary<string, HttpControllerDescriptor> GetControllerMapping()
        {
            return _controllers;
        }

        private string GetVersion(HttpRequestMessage request)
        {
            IEnumerable<string> values;
            string apiVersion = null;

            if (request.Headers.TryGetValues(Common.Classes.Constants.ApiVersion, out values))
            {
                apiVersion = values.FirstOrDefault();
            }

            return apiVersion;
        }

        private static T GetRouteVariable<T>(IHttpRouteData routeData, string name)
        {
            object result = null;
            if (routeData.Values.TryGetValue(name, out result))
            {
                return (T)result;
            }
            return default(T);
        }
    }
}
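The version fallback in TryGetControllerWithMatchingMethod can be isolated from Web API to make the idea easier to see. The sketch below is not the selector itself: it assumes a hypothetical dictionary standing in for the reflection lookup, keyed like the selector's "V{n}{Controller}" keys, and walks down from the requested version until it finds a controller version that defines the action. Unlike the selector's OrdinalIgnoreCase dictionary, this one uses whatever comparer the caller gives it.

```csharp
using System;
using System.Collections.Generic;

public static class VersionFallback
{
    // 'available' is a hypothetical stand-in for the reflection lookup:
    // key "V{n}{Controller}", value = action names defined on that controller version.
    // Returns the first version (from requestedVersion down to 1) that defines the action,
    // or null when no version matches.
    public static int? Resolve(
        IDictionary<string, ISet<string>> available, int requestedVersion, string controller, string action)
    {
        for (int v = requestedVersion; v >= 1; v--)
        {
            ISet<string> actions;
            if (available.TryGetValue("V" + v + controller, out actions) && actions.Contains(action))
            {
                return v;
            }
        }
        return null;
    }
}
```

With V1Orders defining Get and Post and V2Orders defining only Post, a V2 request for Post resolves to version 2, while a V2 request for Get falls back to version 1.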

Answer 4

Reflection is relatively expensive, so running it on every request adds overhead.

You should really be unit testing these methods to ensure they are redirecting to the appropriate action and controller.

For example, with NUnit:

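using NUnit.Framework;
using System.Web.Mvc;

[TestFixture]
public class MyControllerTests
{
    [Test]
    public void MyTestAction_Redirects_To_MyOtherAction()
    {
        var controller = new MyController();

        var result = (RedirectToRouteResult)controller.MyTestAction();

        Assert.That(result.RouteValues["action"], Is.EqualTo("MyOtherAction"));
        // Route values hold the controller name without the "Controller" suffix.
        Assert.That(result.RouteValues["controller"], Is.EqualTo("MyOther"));
    }
}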

Comment:

Well, since a route constraint object is created only once, you could reflect over all the controllers and actions once and cache the results.
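A minimal sketch of that caching idea, assuming a hypothetical MyControllerBase class in place of System.Web.Mvc.Controller so it runs without MVC: reflect over the assembly once, store every "Controller/Action" pair in a case-insensitive HashSet, and answer later checks from the set. A route constraint (or any other caller) could then call the hypothetical ActionCache.Exists instead of reflecting per request.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical base class standing in for System.Web.Mvc.Controller.
public abstract class MyControllerBase { }

public class AccountController : MyControllerBase
{
    public void Login() { }
    public void Logout() { }
}

public static class ActionCache
{
    // Built once, when the class is initialized; later lookups are HashSet hits, not reflection.
    private static readonly HashSet<string> Known = Build();

    private static HashSet<string> Build()
    {
        const string suffix = "Controller";
        var set = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
        var controllers = Assembly.GetExecutingAssembly().GetTypes()
            .Where(t => typeof(MyControllerBase).IsAssignableFrom(t) && !t.IsAbstract
                        && t.Name.EndsWith(suffix, StringComparison.Ordinal));

        foreach (Type t in controllers)
        {
            string name = t.Name.Substring(0, t.Name.Length - suffix.Length);
            // DeclaredOnly skips inherited members such as ToString and GetHashCode.
            foreach (MethodInfo m in t.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly))
            {
                set.Add(name + "/" + m.Name);
            }
        }
        return set;
    }

    public static bool Exists(string controller, string action)
    {
        return Known.Contains(controller + "/" + action);
    }
}
```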
